diff --git a/benchmarks/benchmark.c b/benchmarks/benchmark.c index d64e5c4ee..d8f98ea34 100644 --- a/benchmarks/benchmark.c +++ b/benchmarks/benchmark.c @@ -444,6 +444,270 @@ uint8_t mlkem_rej_uniform_table[] = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 // 255 }; +// Constant lookup table for ML-DSA rejection sampling. Matches the byte-list +// table in x86/proofs/mldsa_rej_uniform_table.ml (256 entries, 8 bytes each = +// 2048 bytes) interpreted as a uint64_t[256] table of VPERMD indices. + +uint8_t mldsa_rej_uniform_table[] = +{ + 0, 0, 0, 0, 0, 0, 0, 0, // 0 + 0, 0, 0, 0, 0, 0, 0, 0, // 1 + 1, 0, 0, 0, 0, 0, 0, 0, // 2 + 0, 1, 0, 0, 0, 0, 0, 0, // 3 + 2, 0, 0, 0, 0, 0, 0, 0, // 4 + 0, 2, 0, 0, 0, 0, 0, 0, // 5 + 1, 2, 0, 0, 0, 0, 0, 0, // 6 + 0, 1, 2, 0, 0, 0, 0, 0, // 7 + 3, 0, 0, 0, 0, 0, 0, 0, // 8 + 0, 3, 0, 0, 0, 0, 0, 0, // 9 + 1, 3, 0, 0, 0, 0, 0, 0, // 10 + 0, 1, 3, 0, 0, 0, 0, 0, // 11 + 2, 3, 0, 0, 0, 0, 0, 0, // 12 + 0, 2, 3, 0, 0, 0, 0, 0, // 13 + 1, 2, 3, 0, 0, 0, 0, 0, // 14 + 0, 1, 2, 3, 0, 0, 0, 0, // 15 + 4, 0, 0, 0, 0, 0, 0, 0, // 16 + 0, 4, 0, 0, 0, 0, 0, 0, // 17 + 1, 4, 0, 0, 0, 0, 0, 0, // 18 + 0, 1, 4, 0, 0, 0, 0, 0, // 19 + 2, 4, 0, 0, 0, 0, 0, 0, // 20 + 0, 2, 4, 0, 0, 0, 0, 0, // 21 + 1, 2, 4, 0, 0, 0, 0, 0, // 22 + 0, 1, 2, 4, 0, 0, 0, 0, // 23 + 3, 4, 0, 0, 0, 0, 0, 0, // 24 + 0, 3, 4, 0, 0, 0, 0, 0, // 25 + 1, 3, 4, 0, 0, 0, 0, 0, // 26 + 0, 1, 3, 4, 0, 0, 0, 0, // 27 + 2, 3, 4, 0, 0, 0, 0, 0, // 28 + 0, 2, 3, 4, 0, 0, 0, 0, // 29 + 1, 2, 3, 4, 0, 0, 0, 0, // 30 + 0, 1, 2, 3, 4, 0, 0, 0, // 31 + 5, 0, 0, 0, 0, 0, 0, 0, // 32 + 0, 5, 0, 0, 0, 0, 0, 0, // 33 + 1, 5, 0, 0, 0, 0, 0, 0, // 34 + 0, 1, 5, 0, 0, 0, 0, 0, // 35 + 2, 5, 0, 0, 0, 0, 0, 0, // 36 + 0, 2, 5, 0, 0, 0, 0, 0, // 37 + 1, 2, 5, 0, 0, 0, 0, 0, // 38 + 0, 1, 2, 5, 0, 0, 0, 0, // 39 + 3, 5, 0, 0, 0, 0, 0, 0, // 40 + 0, 3, 5, 0, 0, 0, 0, 0, // 41 + 1, 3, 5, 0, 0, 0, 0, 0, // 42 + 0, 1, 3, 5, 0, 0, 0, 0, // 43 + 2, 3, 5, 0, 0, 0, 0, 0, // 44 + 0, 2, 3, 5, 0, 0, 0, 0, // 45 + 1, 2, 3, 5, 0, 0, 0, 0, // 46 + 0, 1, 2, 3, 5, 0, 0, 0, // 47 + 4, 5, 0, 0, 0, 0, 0, 0, // 48 + 0, 4, 5, 0, 0, 0, 0, 0, // 49 + 1, 4, 5, 0, 0, 0, 0, 0, // 50 + 0, 1, 4, 5, 0, 0, 0, 0, // 51 + 2, 4, 5, 0, 0, 0, 0, 0, // 52 + 0, 2, 4, 5, 0, 0, 0, 0, // 53 + 1, 2, 4, 5, 0, 0, 0, 0, // 54 + 0, 1, 2, 4, 5, 0, 0, 0, // 55 + 3, 4, 5, 0, 0, 0, 0, 0, // 56 + 0, 3, 4, 5, 0, 0, 0, 0, // 57 + 1, 3, 4, 5, 0, 0, 0, 0, // 58 + 0, 1, 3, 4, 5, 0, 0, 0, // 59 + 2, 3, 4, 5, 0, 0, 0, 0, // 60 + 0, 2, 3, 4, 5, 0, 0, 0, // 61 + 1, 2, 3, 4, 5, 0, 0, 0, // 62 + 0, 1, 2, 3, 4, 5, 0, 0, // 63 + 6, 0, 0, 0, 0, 0, 0, 0, // 64 + 0, 6, 0, 0, 0, 0, 0, 0, // 65 + 1, 6, 0, 0, 0, 0, 0, 0, // 66 + 0, 1, 6, 0, 0, 0, 0, 0, // 67 + 2, 6, 0, 0, 0, 0, 0, 0, // 68 + 0, 2, 6, 0, 0, 0, 0, 0, // 69 + 1, 2, 6, 0, 0, 0, 0, 0, // 70 + 0, 1, 2, 6, 0, 0, 0, 0, // 71 + 3, 6, 0, 0, 0, 0, 0, 0, // 72 + 0, 3, 6, 0, 0, 0, 0, 0, // 73 + 1, 3, 6, 0, 0, 0, 0, 0, // 74 + 0, 1, 3, 6, 0, 0, 0, 0, // 75 + 2, 3, 6, 0, 0, 0, 0, 0, // 76 + 0, 2, 3, 6, 0, 0, 0, 0, // 77 + 1, 2, 3, 6, 0, 0, 0, 0, // 78 + 0, 1, 2, 3, 6, 0, 0, 0, // 79 + 4, 6, 0, 0, 0, 0, 0, 0, // 80 + 0, 4, 6, 0, 0, 0, 0, 0, // 81 + 1, 4, 6, 0, 0, 0, 0, 0, // 82 + 0, 1, 4, 6, 0, 0, 0, 0, // 83 + 2, 4, 6, 0, 0, 0, 0, 0, // 84 + 0, 2, 4, 6, 0, 0, 0, 0, // 85 + 1, 2, 4, 6, 0, 0, 0, 0, // 86 + 0, 1, 2, 4, 6, 0, 0, 0, // 87 + 3, 4, 6, 0, 0, 0, 0, 0, // 88 + 0, 3, 4, 6, 0, 0, 0, 0, // 89 + 1, 3, 4, 6, 0, 0, 0, 0, // 90 + 0, 1, 3, 4, 6, 0, 0, 0, // 91 + 2, 3, 4, 6, 0, 0, 0, 0, // 92 + 0, 2, 3, 4, 6, 0, 0, 0, // 93 + 1, 2, 3, 4, 6, 0, 0, 0, // 94 + 0, 1, 2, 3, 4, 6, 0, 0, // 95 + 5, 6, 0, 0, 0, 0, 0, 0, // 96 + 0, 5, 6, 0, 0, 0, 0, 0, // 97 + 1, 5, 6, 0, 0, 0, 0, 0, // 98 + 0, 1, 5, 6, 0, 0, 0, 0, // 99 + 2, 5, 6, 0, 0, 0, 0, 0, // 100 + 0, 2, 5, 6, 0, 0, 0, 0, // 101 + 1, 2, 5, 6, 0, 0, 0, 0, // 102 + 0, 1, 2, 5, 6, 0, 0, 0, // 103 + 3, 5, 6, 0, 0, 0, 0, 0, // 104 + 0, 3, 5, 6, 0, 0, 0, 0, // 105 + 1, 3, 5, 6, 0, 0, 0, 0, // 106 + 0, 1, 3, 5, 6, 0, 0, 0, // 107 + 2, 3, 5, 6, 0, 0, 0, 0, // 108 + 0, 2, 3, 5, 6, 0, 0, 0, // 109 + 1, 2, 3, 5, 6, 0, 0, 0, // 110 + 0, 1, 2, 3, 5, 6, 0, 0, // 111 + 4, 5, 6, 0, 0, 0, 0, 0, // 112 + 0, 4, 5, 6, 0, 0, 0, 0, // 113 + 1, 4, 5, 6, 0, 0, 0, 0, // 114 + 0, 1, 4, 5, 6, 0, 0, 0, // 115 + 2, 4, 5, 6, 0, 0, 0, 0, // 116 + 0, 2, 4, 5, 6, 0, 0, 0, // 117 + 1, 2, 4, 5, 6, 0, 0, 0, // 118 + 0, 1, 2, 4, 5, 6, 0, 0, // 119 + 3, 4, 5, 6, 0, 0, 0, 0, // 120 + 0, 3, 4, 5, 6, 0, 0, 0, // 121 + 1, 3, 4, 5, 6, 0, 0, 0, // 122 + 0, 1, 3, 4, 5, 6, 0, 0, // 123 + 2, 3, 4, 5, 6, 0, 0, 0, // 124 + 0, 2, 3, 4, 5, 6, 0, 0, // 125 + 1, 2, 3, 4, 5, 6, 0, 0, // 126 + 0, 1, 2, 3, 4, 5, 6, 0, // 127 + 7, 0, 0, 0, 0, 0, 0, 0, // 128 + 0, 7, 0, 0, 0, 0, 0, 0, // 129 + 1, 7, 0, 0, 0, 0, 0, 0, // 130 + 0, 1, 7, 0, 0, 0, 0, 0, // 131 + 2, 7, 0, 0, 0, 0, 0, 0, // 132 + 0, 2, 7, 0, 0, 0, 0, 0, // 133 + 1, 2, 7, 0, 0, 0, 0, 0, // 134 + 0, 1, 2, 7, 0, 0, 0, 0, // 135 + 3, 7, 0, 0, 0, 0, 0, 0, // 136 + 0, 3, 7, 0, 0, 0, 0, 0, // 137 + 1, 3, 7, 0, 0, 0, 0, 0, // 138 + 0, 1, 3, 7, 0, 0, 0, 0, // 139 + 2, 3, 7, 0, 0, 0, 0, 0, // 140 + 0, 2, 3, 7, 0, 0, 0, 0, // 141 + 1, 2, 3, 7, 0, 0, 0, 0, // 142 + 0, 1, 2, 3, 7, 0, 0, 0, // 143 + 4, 7, 0, 0, 0, 0, 0, 0, // 144 + 0, 4, 7, 0, 0, 0, 0, 0, // 145 + 1, 4, 7, 0, 0, 0, 0, 0, // 146 + 0, 1, 4, 7, 0, 0, 0, 0, // 147 + 2, 4, 7, 0, 0, 0, 0, 0, // 148 + 0, 2, 4, 7, 0, 0, 0, 0, // 149 + 1, 2, 4, 7, 0, 0, 0, 0, // 150 + 0, 1, 2, 4, 7, 0, 0, 0, // 151 + 3, 4, 7, 0, 0, 0, 0, 0, // 152 + 0, 3, 4, 7, 0, 0, 0, 0, // 153 + 1, 3, 4, 7, 0, 0, 0, 0, // 154 + 0, 1, 3, 4, 7, 0, 0, 0, // 155 + 2, 3, 4, 7, 0, 0, 0, 0, // 156 + 0, 2, 3, 4, 7, 0, 0, 0, // 157 + 1, 2, 3, 4, 7, 0, 0, 0, // 158 + 0, 1, 2, 3, 4, 7, 0, 0, // 159 + 5, 7, 0, 0, 0, 0, 0, 0, // 160 + 0, 5, 7, 0, 0, 0, 0, 0, // 161 + 1, 5, 7, 0, 0, 0, 0, 0, // 162 + 0, 1, 5, 7, 0, 0, 0, 0, // 163 + 2, 5, 7, 0, 0, 0, 0, 0, // 164 + 0, 2, 5, 7, 0, 0, 0, 0, // 165 + 1, 2, 5, 7, 0, 0, 0, 0, // 166 + 0, 1, 2, 5, 7, 0, 0, 0, // 167 + 3, 5, 7, 0, 0, 0, 0, 0, // 168 + 0, 3, 5, 7, 0, 0, 0, 0, // 169 + 1, 3, 5, 7, 0, 0, 0, 0, // 170 + 0, 1, 3, 5, 7, 0, 0, 0, // 171 + 2, 3, 5, 7, 0, 0, 0, 0, // 172 + 0, 2, 3, 5, 7, 0, 0, 0, // 173 + 1, 2, 3, 5, 7, 0, 0, 0, // 174 + 0, 1, 2, 3, 5, 7, 0, 0, // 175 + 4, 5, 7, 0, 0, 0, 0, 0, // 176 + 0, 4, 5, 7, 0, 0, 0, 0, // 177 + 1, 4, 5, 7, 0, 0, 0, 0, // 178 + 0, 1, 4, 5, 7, 0, 0, 0, // 179 + 2, 4, 5, 7, 0, 0, 0, 0, // 180 + 0, 2, 4, 5, 7, 0, 0, 0, // 181 + 1, 2, 4, 5, 7, 0, 0, 0, // 182 + 0, 1, 2, 4, 5, 7, 0, 0, // 183 + 3, 4, 5, 7, 0, 0, 0, 0, // 184 + 0, 3, 4, 5, 7, 0, 0, 0, // 185 + 1, 3, 4, 5, 7, 0, 0, 0, // 186 + 0, 1, 3, 4, 5, 7, 0, 0, // 187 + 2, 3, 4, 5, 7, 0, 0, 0, // 188 + 0, 2, 3, 4, 5, 7, 0, 0, // 189 + 1, 2, 3, 4, 5, 7, 0, 0, // 190 + 0, 1, 2, 3, 4, 5, 7, 0, // 191 + 6, 7, 0, 0, 0, 0, 0, 0, // 192 + 0, 6, 7, 0, 0, 0, 0, 0, // 193 + 1, 6, 7, 0, 0, 0, 0, 0, // 194 + 0, 1, 6, 7, 0, 0, 0, 0, // 195 + 2, 6, 7, 0, 0, 0, 0, 0, // 196 + 0, 2, 6, 7, 0, 0, 0, 0, // 197 + 1, 2, 6, 7, 0, 0, 0, 0, // 198 + 0, 1, 2, 6, 7, 0, 0, 0, // 199 + 3, 6, 7, 0, 0, 0, 0, 0, // 200 + 0, 3, 6, 7, 0, 0, 0, 0, // 201 + 1, 3, 6, 7, 0, 0, 0, 0, // 202 + 0, 1, 3, 6, 7, 0, 0, 0, // 203 + 2, 3, 6, 7, 0, 0, 0, 0, // 204 + 0, 2, 3, 6, 7, 0, 0, 0, // 205 + 1, 2, 3, 6, 7, 0, 0, 0, // 206 + 0, 1, 2, 3, 6, 7, 0, 0, // 207 + 4, 6, 7, 0, 0, 0, 0, 0, // 208 + 0, 4, 6, 7, 0, 0, 0, 0, // 209 + 1, 4, 6, 7, 0, 0, 0, 0, // 210 + 0, 1, 4, 6, 7, 0, 0, 0, // 211 + 2, 4, 6, 7, 0, 0, 0, 0, // 212 + 0, 2, 4, 6, 7, 0, 0, 0, // 213 + 1, 2, 4, 6, 7, 0, 0, 0, // 214 + 0, 1, 2, 4, 6, 7, 0, 0, // 215 + 3, 4, 6, 7, 0, 0, 0, 0, // 216 + 0, 3, 4, 6, 7, 0, 0, 0, // 217 + 1, 3, 4, 6, 7, 0, 0, 0, // 218 + 0, 1, 3, 4, 6, 7, 0, 0, // 219 + 2, 3, 4, 6, 7, 0, 0, 0, // 220 + 0, 2, 3, 4, 6, 7, 0, 0, // 221 + 1, 2, 3, 4, 6, 7, 0, 0, // 222 + 0, 1, 2, 3, 4, 6, 7, 0, // 223 + 5, 6, 7, 0, 0, 0, 0, 0, // 224 + 0, 5, 6, 7, 0, 0, 0, 0, // 225 + 1, 5, 6, 7, 0, 0, 0, 0, // 226 + 0, 1, 5, 6, 7, 0, 0, 0, // 227 + 2, 5, 6, 7, 0, 0, 0, 0, // 228 + 0, 2, 5, 6, 7, 0, 0, 0, // 229 + 1, 2, 5, 6, 7, 0, 0, 0, // 230 + 0, 1, 2, 5, 6, 7, 0, 0, // 231 + 3, 5, 6, 7, 0, 0, 0, 0, // 232 + 0, 3, 5, 6, 7, 0, 0, 0, // 233 + 1, 3, 5, 6, 7, 0, 0, 0, // 234 + 0, 1, 3, 5, 6, 7, 0, 0, // 235 + 2, 3, 5, 6, 7, 0, 0, 0, // 236 + 0, 2, 3, 5, 6, 7, 0, 0, // 237 + 1, 2, 3, 5, 6, 7, 0, 0, // 238 + 0, 1, 2, 3, 5, 6, 7, 0, // 239 + 4, 5, 6, 7, 0, 0, 0, 0, // 240 + 0, 4, 5, 6, 7, 0, 0, 0, // 241 + 1, 4, 5, 6, 7, 0, 0, 0, // 242 + 0, 1, 4, 5, 6, 7, 0, 0, // 243 + 2, 4, 5, 6, 7, 0, 0, 0, // 244 + 0, 2, 4, 5, 6, 7, 0, 0, // 245 + 1, 2, 4, 5, 6, 7, 0, 0, // 246 + 0, 1, 2, 4, 5, 6, 7, 0, // 247 + 3, 4, 5, 6, 7, 0, 0, 0, // 248 + 0, 3, 4, 5, 6, 7, 0, 0, // 249 + 1, 3, 4, 5, 6, 7, 0, 0, // 250 + 0, 1, 3, 4, 5, 6, 7, 0, // 251 + 2, 3, 4, 5, 6, 7, 0, 0, // 252 + 0, 2, 3, 4, 5, 6, 7, 0, // 253 + 1, 2, 3, 4, 5, 6, 7, 0, // 254 + 0, 1, 2, 3, 4, 5, 6, 7 // 255 +}; + // Wrappers round the functions to call uniformly void call_bignum_add__4_4(void) repeat(bignum_add(4,b0,4,b1,4,b2)) @@ -1112,6 +1376,7 @@ void call_mldsa_pointwise_acc_l4(void) repeat(mldsa_pointwise_acc_l4_x86((int32_ void call_mldsa_pointwise_acc_l5(void) repeat(mldsa_pointwise_acc_l5_x86((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2,mldsa_avx2_qdata)) void call_mldsa_pointwise_acc_l7(void) repeat(mldsa_pointwise_acc_l7_x86((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2,mldsa_avx2_qdata)) void call_mldsa_reduce(void) repeat(mldsa_reduce((int32_t*)b0)) +void call_mldsa_rej_uniform(void) repeat(mldsa_rej_uniform((int32_t*)b0,(uint8_t*)b1,(const uint64_t*)mldsa_rej_uniform_table)) void call_mlkem_frombytes(void) repeat(mlkem_frombytes((uint16_t*)b0,(int8_t*)b1)) void call_mlkem_intt(void) repeat(mlkem_intt_x86((int16_t*)b0,(int16_t*)b1)) @@ -1156,6 +1421,7 @@ void call_mldsa_pointwise_acc_l4(void) repeat(mldsa_pointwise_acc_l4((int32_t*)b void call_mldsa_pointwise_acc_l5(void) repeat(mldsa_pointwise_acc_l5((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2)) void call_mldsa_pointwise_acc_l7(void) repeat(mldsa_pointwise_acc_l7((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2)) void call_mldsa_reduce(void) {} +void call_mldsa_rej_uniform(void) {} void call_bignum_copy_row_from_table_8n__32_16(void) \ repeat(bignum_copy_row_from_table_8n(b0,b1,32,16,0)) @@ -1630,6 +1896,7 @@ int main(int argc, char *argv[]) timingtest(all,"mldsa_pointwise_acc_l5",call_mldsa_pointwise_acc_l5); timingtest(all,"mldsa_pointwise_acc_l7",call_mldsa_pointwise_acc_l7); timingtest(!arm,"mldsa_reduce",call_mldsa_reduce); + timingtest(!arm,"mldsa_rej_uniform",call_mldsa_rej_uniform); timingtest(bmi,"p256_montjadd",call_p256_montjadd); timingtest(all,"p256_montjadd_alt",call_p256_montjadd_alt); timingtest(bmi,"p256_montjdouble",call_p256_montjdouble); diff --git a/include/s2n-bignum-c89.h b/include/s2n-bignum-c89.h index 38f2c4893..d586eb736 100644 --- a/include/s2n-bignum-c89.h +++ b/include/s2n-bignum-c89.h @@ -1013,6 +1013,13 @@ extern void mldsa_ntt(int32_t a[256], const int32_t zetas[624]); /* Input a[256] (signed 32-bit words); output a[256] (signed 32-bit words) */ extern void mldsa_nttunpack(int32_t a[256]); +/* Uniform rejection sampling for ML-DSA: extract 23-bit coefficients from */ +/* 3-byte-packed input, keeping only those strictly less than q = 8380417. */ +/* Returns the number of accepted coefficients (at most 256). */ +/* Inputs buf[840] (uint8_t), table[256] (uint64_t lookup table); */ +/* output r[256] (int32_t). */ +extern uint32_t mldsa_rej_uniform(int32_t r[256], const uint8_t buf[840], const uint64_t table[256]); + /* Pointwise multiplication of polynomials in NTT domain (Montgomery form) for ML-DSA */ /* Inputs a[256], b[256] (signed 32-bit words); output r[256] (signed 32-bit words) */ extern void mldsa_pointwise(int32_t r[256], const int32_t a[256], const int32_t b[256]); diff --git a/include/s2n-bignum.h b/include/s2n-bignum.h index 2c9273750..f584421ce 100644 --- a/include/s2n-bignum.h +++ b/include/s2n-bignum.h @@ -1018,6 +1018,13 @@ extern void mldsa_ntt(int32_t a[S2N_BIGNUM_STATIC 256], const int32_t zetas[S2N_ // Input a[256] (signed 32-bit words); output a[256] (signed 32-bit words) extern void mldsa_nttunpack(int32_t a[S2N_BIGNUM_STATIC 256]); +// Uniform rejection sampling for ML-DSA: extract 23-bit coefficients from +// 3-byte-packed input, keeping only those strictly less than q = 8380417. +// Returns the number of accepted coefficients (at most 256). +// Inputs buf[840] (uint8_t), table[256] (uint64_t lookup table); +// output r[256] (int32_t). +extern uint32_t mldsa_rej_uniform(int32_t r[S2N_BIGNUM_STATIC 256], const uint8_t buf[S2N_BIGNUM_STATIC 840], const uint64_t table[S2N_BIGNUM_STATIC 256]); + // Pointwise multiplication of polynomials in NTT domain (Montgomery form) for ML-DSA // Inputs a[256], b[256] (signed 32-bit words); output r[256] (signed 32-bit words) extern void mldsa_pointwise(int32_t r[S2N_BIGNUM_STATIC 256], const int32_t a[S2N_BIGNUM_STATIC 256], const int32_t b[S2N_BIGNUM_STATIC 256]); diff --git a/tests/test.c b/tests/test.c index a545ceeac..01b8cbcdf 100644 --- a/tests/test.c +++ b/tests/test.c @@ -825,6 +825,270 @@ uint8_t mlkem_rej_uniform_table[] = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 // 255 }; +// Constant lookup table for ML-DSA rejection sampling. Matches the byte-list +// table in x86/proofs/mldsa_rej_uniform_table.ml (256 entries, 8 bytes each = +// 2048 bytes) interpreted as a uint64_t[256] table of VPERMD indices. + +uint8_t mldsa_rej_uniform_table[] = +{ + 0, 0, 0, 0, 0, 0, 0, 0, // 0 + 0, 0, 0, 0, 0, 0, 0, 0, // 1 + 1, 0, 0, 0, 0, 0, 0, 0, // 2 + 0, 1, 0, 0, 0, 0, 0, 0, // 3 + 2, 0, 0, 0, 0, 0, 0, 0, // 4 + 0, 2, 0, 0, 0, 0, 0, 0, // 5 + 1, 2, 0, 0, 0, 0, 0, 0, // 6 + 0, 1, 2, 0, 0, 0, 0, 0, // 7 + 3, 0, 0, 0, 0, 0, 0, 0, // 8 + 0, 3, 0, 0, 0, 0, 0, 0, // 9 + 1, 3, 0, 0, 0, 0, 0, 0, // 10 + 0, 1, 3, 0, 0, 0, 0, 0, // 11 + 2, 3, 0, 0, 0, 0, 0, 0, // 12 + 0, 2, 3, 0, 0, 0, 0, 0, // 13 + 1, 2, 3, 0, 0, 0, 0, 0, // 14 + 0, 1, 2, 3, 0, 0, 0, 0, // 15 + 4, 0, 0, 0, 0, 0, 0, 0, // 16 + 0, 4, 0, 0, 0, 0, 0, 0, // 17 + 1, 4, 0, 0, 0, 0, 0, 0, // 18 + 0, 1, 4, 0, 0, 0, 0, 0, // 19 + 2, 4, 0, 0, 0, 0, 0, 0, // 20 + 0, 2, 4, 0, 0, 0, 0, 0, // 21 + 1, 2, 4, 0, 0, 0, 0, 0, // 22 + 0, 1, 2, 4, 0, 0, 0, 0, // 23 + 3, 4, 0, 0, 0, 0, 0, 0, // 24 + 0, 3, 4, 0, 0, 0, 0, 0, // 25 + 1, 3, 4, 0, 0, 0, 0, 0, // 26 + 0, 1, 3, 4, 0, 0, 0, 0, // 27 + 2, 3, 4, 0, 0, 0, 0, 0, // 28 + 0, 2, 3, 4, 0, 0, 0, 0, // 29 + 1, 2, 3, 4, 0, 0, 0, 0, // 30 + 0, 1, 2, 3, 4, 0, 0, 0, // 31 + 5, 0, 0, 0, 0, 0, 0, 0, // 32 + 0, 5, 0, 0, 0, 0, 0, 0, // 33 + 1, 5, 0, 0, 0, 0, 0, 0, // 34 + 0, 1, 5, 0, 0, 0, 0, 0, // 35 + 2, 5, 0, 0, 0, 0, 0, 0, // 36 + 0, 2, 5, 0, 0, 0, 0, 0, // 37 + 1, 2, 5, 0, 0, 0, 0, 0, // 38 + 0, 1, 2, 5, 0, 0, 0, 0, // 39 + 3, 5, 0, 0, 0, 0, 0, 0, // 40 + 0, 3, 5, 0, 0, 0, 0, 0, // 41 + 1, 3, 5, 0, 0, 0, 0, 0, // 42 + 0, 1, 3, 5, 0, 0, 0, 0, // 43 + 2, 3, 5, 0, 0, 0, 0, 0, // 44 + 0, 2, 3, 5, 0, 0, 0, 0, // 45 + 1, 2, 3, 5, 0, 0, 0, 0, // 46 + 0, 1, 2, 3, 5, 0, 0, 0, // 47 + 4, 5, 0, 0, 0, 0, 0, 0, // 48 + 0, 4, 5, 0, 0, 0, 0, 0, // 49 + 1, 4, 5, 0, 0, 0, 0, 0, // 50 + 0, 1, 4, 5, 0, 0, 0, 0, // 51 + 2, 4, 5, 0, 0, 0, 0, 0, // 52 + 0, 2, 4, 5, 0, 0, 0, 0, // 53 + 1, 2, 4, 5, 0, 0, 0, 0, // 54 + 0, 1, 2, 4, 5, 0, 0, 0, // 55 + 3, 4, 5, 0, 0, 0, 0, 0, // 56 + 0, 3, 4, 5, 0, 0, 0, 0, // 57 + 1, 3, 4, 5, 0, 0, 0, 0, // 58 + 0, 1, 3, 4, 5, 0, 0, 0, // 59 + 2, 3, 4, 5, 0, 0, 0, 0, // 60 + 0, 2, 3, 4, 5, 0, 0, 0, // 61 + 1, 2, 3, 4, 5, 0, 0, 0, // 62 + 0, 1, 2, 3, 4, 5, 0, 0, // 63 + 6, 0, 0, 0, 0, 0, 0, 0, // 64 + 0, 6, 0, 0, 0, 0, 0, 0, // 65 + 1, 6, 0, 0, 0, 0, 0, 0, // 66 + 0, 1, 6, 0, 0, 0, 0, 0, // 67 + 2, 6, 0, 0, 0, 0, 0, 0, // 68 + 0, 2, 6, 0, 0, 0, 0, 0, // 69 + 1, 2, 6, 0, 0, 0, 0, 0, // 70 + 0, 1, 2, 6, 0, 0, 0, 0, // 71 + 3, 6, 0, 0, 0, 0, 0, 0, // 72 + 0, 3, 6, 0, 0, 0, 0, 0, // 73 + 1, 3, 6, 0, 0, 0, 0, 0, // 74 + 0, 1, 3, 6, 0, 0, 0, 0, // 75 + 2, 3, 6, 0, 0, 0, 0, 0, // 76 + 0, 2, 3, 6, 0, 0, 0, 0, // 77 + 1, 2, 3, 6, 0, 0, 0, 0, // 78 + 0, 1, 2, 3, 6, 0, 0, 0, // 79 + 4, 6, 0, 0, 0, 0, 0, 0, // 80 + 0, 4, 6, 0, 0, 0, 0, 0, // 81 + 1, 4, 6, 0, 0, 0, 0, 0, // 82 + 0, 1, 4, 6, 0, 0, 0, 0, // 83 + 2, 4, 6, 0, 0, 0, 0, 0, // 84 + 0, 2, 4, 6, 0, 0, 0, 0, // 85 + 1, 2, 4, 6, 0, 0, 0, 0, // 86 + 0, 1, 2, 4, 6, 0, 0, 0, // 87 + 3, 4, 6, 0, 0, 0, 0, 0, // 88 + 0, 3, 4, 6, 0, 0, 0, 0, // 89 + 1, 3, 4, 6, 0, 0, 0, 0, // 90 + 0, 1, 3, 4, 6, 0, 0, 0, // 91 + 2, 3, 4, 6, 0, 0, 0, 0, // 92 + 0, 2, 3, 4, 6, 0, 0, 0, // 93 + 1, 2, 3, 4, 6, 0, 0, 0, // 94 + 0, 1, 2, 3, 4, 6, 0, 0, // 95 + 5, 6, 0, 0, 0, 0, 0, 0, // 96 + 0, 5, 6, 0, 0, 0, 0, 0, // 97 + 1, 5, 6, 0, 0, 0, 0, 0, // 98 + 0, 1, 5, 6, 0, 0, 0, 0, // 99 + 2, 5, 6, 0, 0, 0, 0, 0, // 100 + 0, 2, 5, 6, 0, 0, 0, 0, // 101 + 1, 2, 5, 6, 0, 0, 0, 0, // 102 + 0, 1, 2, 5, 6, 0, 0, 0, // 103 + 3, 5, 6, 0, 0, 0, 0, 0, // 104 + 0, 3, 5, 6, 0, 0, 0, 0, // 105 + 1, 3, 5, 6, 0, 0, 0, 0, // 106 + 0, 1, 3, 5, 6, 0, 0, 0, // 107 + 2, 3, 5, 6, 0, 0, 0, 0, // 108 + 0, 2, 3, 5, 6, 0, 0, 0, // 109 + 1, 2, 3, 5, 6, 0, 0, 0, // 110 + 0, 1, 2, 3, 5, 6, 0, 0, // 111 + 4, 5, 6, 0, 0, 0, 0, 0, // 112 + 0, 4, 5, 6, 0, 0, 0, 0, // 113 + 1, 4, 5, 6, 0, 0, 0, 0, // 114 + 0, 1, 4, 5, 6, 0, 0, 0, // 115 + 2, 4, 5, 6, 0, 0, 0, 0, // 116 + 0, 2, 4, 5, 6, 0, 0, 0, // 117 + 1, 2, 4, 5, 6, 0, 0, 0, // 118 + 0, 1, 2, 4, 5, 6, 0, 0, // 119 + 3, 4, 5, 6, 0, 0, 0, 0, // 120 + 0, 3, 4, 5, 6, 0, 0, 0, // 121 + 1, 3, 4, 5, 6, 0, 0, 0, // 122 + 0, 1, 3, 4, 5, 6, 0, 0, // 123 + 2, 3, 4, 5, 6, 0, 0, 0, // 124 + 0, 2, 3, 4, 5, 6, 0, 0, // 125 + 1, 2, 3, 4, 5, 6, 0, 0, // 126 + 0, 1, 2, 3, 4, 5, 6, 0, // 127 + 7, 0, 0, 0, 0, 0, 0, 0, // 128 + 0, 7, 0, 0, 0, 0, 0, 0, // 129 + 1, 7, 0, 0, 0, 0, 0, 0, // 130 + 0, 1, 7, 0, 0, 0, 0, 0, // 131 + 2, 7, 0, 0, 0, 0, 0, 0, // 132 + 0, 2, 7, 0, 0, 0, 0, 0, // 133 + 1, 2, 7, 0, 0, 0, 0, 0, // 134 + 0, 1, 2, 7, 0, 0, 0, 0, // 135 + 3, 7, 0, 0, 0, 0, 0, 0, // 136 + 0, 3, 7, 0, 0, 0, 0, 0, // 137 + 1, 3, 7, 0, 0, 0, 0, 0, // 138 + 0, 1, 3, 7, 0, 0, 0, 0, // 139 + 2, 3, 7, 0, 0, 0, 0, 0, // 140 + 0, 2, 3, 7, 0, 0, 0, 0, // 141 + 1, 2, 3, 7, 0, 0, 0, 0, // 142 + 0, 1, 2, 3, 7, 0, 0, 0, // 143 + 4, 7, 0, 0, 0, 0, 0, 0, // 144 + 0, 4, 7, 0, 0, 0, 0, 0, // 145 + 1, 4, 7, 0, 0, 0, 0, 0, // 146 + 0, 1, 4, 7, 0, 0, 0, 0, // 147 + 2, 4, 7, 0, 0, 0, 0, 0, // 148 + 0, 2, 4, 7, 0, 0, 0, 0, // 149 + 1, 2, 4, 7, 0, 0, 0, 0, // 150 + 0, 1, 2, 4, 7, 0, 0, 0, // 151 + 3, 4, 7, 0, 0, 0, 0, 0, // 152 + 0, 3, 4, 7, 0, 0, 0, 0, // 153 + 1, 3, 4, 7, 0, 0, 0, 0, // 154 + 0, 1, 3, 4, 7, 0, 0, 0, // 155 + 2, 3, 4, 7, 0, 0, 0, 0, // 156 + 0, 2, 3, 4, 7, 0, 0, 0, // 157 + 1, 2, 3, 4, 7, 0, 0, 0, // 158 + 0, 1, 2, 3, 4, 7, 0, 0, // 159 + 5, 7, 0, 0, 0, 0, 0, 0, // 160 + 0, 5, 7, 0, 0, 0, 0, 0, // 161 + 1, 5, 7, 0, 0, 0, 0, 0, // 162 + 0, 1, 5, 7, 0, 0, 0, 0, // 163 + 2, 5, 7, 0, 0, 0, 0, 0, // 164 + 0, 2, 5, 7, 0, 0, 0, 0, // 165 + 1, 2, 5, 7, 0, 0, 0, 0, // 166 + 0, 1, 2, 5, 7, 0, 0, 0, // 167 + 3, 5, 7, 0, 0, 0, 0, 0, // 168 + 0, 3, 5, 7, 0, 0, 0, 0, // 169 + 1, 3, 5, 7, 0, 0, 0, 0, // 170 + 0, 1, 3, 5, 7, 0, 0, 0, // 171 + 2, 3, 5, 7, 0, 0, 0, 0, // 172 + 0, 2, 3, 5, 7, 0, 0, 0, // 173 + 1, 2, 3, 5, 7, 0, 0, 0, // 174 + 0, 1, 2, 3, 5, 7, 0, 0, // 175 + 4, 5, 7, 0, 0, 0, 0, 0, // 176 + 0, 4, 5, 7, 0, 0, 0, 0, // 177 + 1, 4, 5, 7, 0, 0, 0, 0, // 178 + 0, 1, 4, 5, 7, 0, 0, 0, // 179 + 2, 4, 5, 7, 0, 0, 0, 0, // 180 + 0, 2, 4, 5, 7, 0, 0, 0, // 181 + 1, 2, 4, 5, 7, 0, 0, 0, // 182 + 0, 1, 2, 4, 5, 7, 0, 0, // 183 + 3, 4, 5, 7, 0, 0, 0, 0, // 184 + 0, 3, 4, 5, 7, 0, 0, 0, // 185 + 1, 3, 4, 5, 7, 0, 0, 0, // 186 + 0, 1, 3, 4, 5, 7, 0, 0, // 187 + 2, 3, 4, 5, 7, 0, 0, 0, // 188 + 0, 2, 3, 4, 5, 7, 0, 0, // 189 + 1, 2, 3, 4, 5, 7, 0, 0, // 190 + 0, 1, 2, 3, 4, 5, 7, 0, // 191 + 6, 7, 0, 0, 0, 0, 0, 0, // 192 + 0, 6, 7, 0, 0, 0, 0, 0, // 193 + 1, 6, 7, 0, 0, 0, 0, 0, // 194 + 0, 1, 6, 7, 0, 0, 0, 0, // 195 + 2, 6, 7, 0, 0, 0, 0, 0, // 196 + 0, 2, 6, 7, 0, 0, 0, 0, // 197 + 1, 2, 6, 7, 0, 0, 0, 0, // 198 + 0, 1, 2, 6, 7, 0, 0, 0, // 199 + 3, 6, 7, 0, 0, 0, 0, 0, // 200 + 0, 3, 6, 7, 0, 0, 0, 0, // 201 + 1, 3, 6, 7, 0, 0, 0, 0, // 202 + 0, 1, 3, 6, 7, 0, 0, 0, // 203 + 2, 3, 6, 7, 0, 0, 0, 0, // 204 + 0, 2, 3, 6, 7, 0, 0, 0, // 205 + 1, 2, 3, 6, 7, 0, 0, 0, // 206 + 0, 1, 2, 3, 6, 7, 0, 0, // 207 + 4, 6, 7, 0, 0, 0, 0, 0, // 208 + 0, 4, 6, 7, 0, 0, 0, 0, // 209 + 1, 4, 6, 7, 0, 0, 0, 0, // 210 + 0, 1, 4, 6, 7, 0, 0, 0, // 211 + 2, 4, 6, 7, 0, 0, 0, 0, // 212 + 0, 2, 4, 6, 7, 0, 0, 0, // 213 + 1, 2, 4, 6, 7, 0, 0, 0, // 214 + 0, 1, 2, 4, 6, 7, 0, 0, // 215 + 3, 4, 6, 7, 0, 0, 0, 0, // 216 + 0, 3, 4, 6, 7, 0, 0, 0, // 217 + 1, 3, 4, 6, 7, 0, 0, 0, // 218 + 0, 1, 3, 4, 6, 7, 0, 0, // 219 + 2, 3, 4, 6, 7, 0, 0, 0, // 220 + 0, 2, 3, 4, 6, 7, 0, 0, // 221 + 1, 2, 3, 4, 6, 7, 0, 0, // 222 + 0, 1, 2, 3, 4, 6, 7, 0, // 223 + 5, 6, 7, 0, 0, 0, 0, 0, // 224 + 0, 5, 6, 7, 0, 0, 0, 0, // 225 + 1, 5, 6, 7, 0, 0, 0, 0, // 226 + 0, 1, 5, 6, 7, 0, 0, 0, // 227 + 2, 5, 6, 7, 0, 0, 0, 0, // 228 + 0, 2, 5, 6, 7, 0, 0, 0, // 229 + 1, 2, 5, 6, 7, 0, 0, 0, // 230 + 0, 1, 2, 5, 6, 7, 0, 0, // 231 + 3, 5, 6, 7, 0, 0, 0, 0, // 232 + 0, 3, 5, 6, 7, 0, 0, 0, // 233 + 1, 3, 5, 6, 7, 0, 0, 0, // 234 + 0, 1, 3, 5, 6, 7, 0, 0, // 235 + 2, 3, 5, 6, 7, 0, 0, 0, // 236 + 0, 2, 3, 5, 6, 7, 0, 0, // 237 + 1, 2, 3, 5, 6, 7, 0, 0, // 238 + 0, 1, 2, 3, 5, 6, 7, 0, // 239 + 4, 5, 6, 7, 0, 0, 0, 0, // 240 + 0, 4, 5, 6, 7, 0, 0, 0, // 241 + 1, 4, 5, 6, 7, 0, 0, 0, // 242 + 0, 1, 4, 5, 6, 7, 0, 0, // 243 + 2, 4, 5, 6, 7, 0, 0, 0, // 244 + 0, 2, 4, 5, 6, 7, 0, 0, // 245 + 1, 2, 4, 5, 6, 7, 0, 0, // 246 + 0, 1, 2, 4, 5, 6, 7, 0, // 247 + 3, 4, 5, 6, 7, 0, 0, 0, // 248 + 0, 3, 4, 5, 6, 7, 0, 0, // 249 + 1, 3, 4, 5, 6, 7, 0, 0, // 250 + 0, 1, 3, 4, 5, 6, 7, 0, // 251 + 2, 3, 4, 5, 6, 7, 0, 0, // 252 + 0, 2, 3, 4, 5, 6, 7, 0, // 253 + 1, 2, 3, 4, 5, 6, 7, 0, // 254 + 0, 1, 2, 3, 4, 5, 6, 7 // 255 +}; + #ifdef __x86_64__ // Constants for the ML-KEM NTT and INTT functions and mulcache. // These are taken with no semantic change from mlkem-native @@ -12813,6 +13077,99 @@ int test_mlkem_rej_uniform(void) return 0; } +// Reference implementation of the ML-DSA rejection sampler: read 3-byte +// chunks from buf as little-endian 24-bit values, mask to 23 bits and keep +// those strictly less than q = 8380417. Stops at 256 accepted coefficients +// or when the buffer is exhausted. This matches the single-loop semantics +// of the FIPS-204 reference; the AVX2 implementation's internal main-vs-tail +// split does not change the externally observable result. +uint32_t reference_mldsa_rej_uniform(int32_t r[256], const uint8_t *buf) +{ + uint32_t ctr = 0; + uint64_t pos = 0; + while (ctr < 256 && pos + 3 <= 840) { + uint32_t v = ((uint32_t)buf[pos] + | ((uint32_t)buf[pos+1] << 8) + | ((uint32_t)buf[pos+2] << 16)) & 0x7FFFFFu; + pos += 3; + if (v < 8380417u) r[ctr++] = (int32_t)v; + } + return ctr; +} + +int test_mldsa_rej_uniform(void) +{ + // The mldsa_rej_uniform assembly is x86_64 AVX2 only. + if (get_arch_name() != ARCH_X86_64) { + return 0; + } + +#ifdef __x86_64__ + uint64_t t, i; + uint8_t inbuf[840]; + int32_t a[256]; + int32_t b[256] __attribute__((aligned(32))); + uint32_t ac, bc; + + printf("Testing mldsa_rej_uniform with %d cases\n", tests); + + for (t = 0; t < tests; ++t) { + for (i = 0; i < 840; ++i) inbuf[i] = (uint8_t) rand(); + + // Natural rejection probability with uniform random 23-bit values is + // only (2^23 - q) / 2^23 ~= 0.098%, so fully random input essentially + // always fills 256 accepted coefficients without exercising the + // rejection-pack path. For roughly a third of iterations, overwrite + // a fraction of 24-bit groups with values in [q, 2^23 - 1] to force + // both the AVX reject-and-compact path and the scalar-tail early + // exit (buffer exhausted before 256 accepts) to be taken. + if ((t & 1) == 0) { + uint32_t inject_mask = 0x3; // ~25% of groups rejected + for (uint64_t g = 0; g + 3 <= 840; g += 3) { + if ((uint32_t)(rand()) & inject_mask) continue; + uint32_t bad = 0x7fe001u + + ((uint32_t)rand() % (0x800000u - 0x7fe001u)); + inbuf[g] = (uint8_t)(bad & 0xff); + inbuf[g + 1] = (uint8_t)((bad >> 8) & 0xff); + // Keep the top (24th) bit untouched: the function masks to 23. + inbuf[g + 2] = (uint8_t)(((bad >> 16) & 0x7f) + | (inbuf[g + 2] & 0x80)); + } + } + + for (i = 0; i < 256; ++i) { a[i] = 0; b[i] = 0; } + + ac = reference_mldsa_rej_uniform(a, inbuf); + bc = mldsa_rej_uniform(b, inbuf, + (const uint64_t *)mldsa_rej_uniform_table); + + if (ac != bc) { + printf("Error in mldsa_rej_uniform count; code = %" PRIu32 + ", ref = %" PRIu32 "\n", bc, ac); + return 1; + } + for (i = 0; i < ac; ++i) { + if (a[i] != b[i]) { + printf("Error in mldsa_rej_uniform; element i = %" PRIu64 + "; code[i] = %" PRId32 + " while reference[i] = %" PRId32 "\n", + i, b[i], a[i]); + return 1; + } + } + if (VERBOSE) { + printf("OK:mldsa_rej_uniform, accepted %4" PRIu32 "/256 " + "[0x%08" PRIx32 ",...,0x%08" PRIx32 "]\n", + bc, b[0], b[(bc == 0) ? 0 : bc - 1]); + } + } + printf("All OK\n"); + return 0; +#else + return 0; +#endif +} + int test_mldsa_reduce(void) { // Skip test on non-x86_64 architectures @@ -16784,6 +17141,7 @@ int main(int argc, char *argv[]) functionaltest(all,"mldsa_pointwise_acc_l5",test_mldsa_pointwise_acc_l5); functionaltest(all,"mldsa_pointwise_acc_l7",test_mldsa_pointwise_acc_l7); functionaltest(all,"mldsa_reduce",test_mldsa_reduce); + functionaltest(all,"mldsa_rej_uniform",test_mldsa_rej_uniform); functionaltest(all,"mlkem_basemul_k2",test_mlkem_basemul_k2); functionaltest(all,"mlkem_basemul_k3",test_mlkem_basemul_k3); functionaltest(all,"mlkem_basemul_k4",test_mlkem_basemul_k4); diff --git a/tools/collect-signatures.py b/tools/collect-signatures.py index 8e28befb3..b55ac2174 100644 --- a/tools/collect-signatures.py +++ b/tools/collect-signatures.py @@ -343,6 +343,7 @@ def stripPrefixes(s, prefixes): "mldsa_pointwise_acc_l7_x86", "mldsa_pointwise_x86", "mldsa_reduce", + "mldsa_rej_uniform", "mlkem_frombytes", "mlkem_mulcache_compute_x86", "mlkem_ntt_x86", diff --git a/tools/list-x86-insns.sh b/tools/list-x86-insns.sh index cdf71ccfe..cdd722994 100755 --- a/tools/list-x86-insns.sh +++ b/tools/list-x86-insns.sh @@ -42,7 +42,7 @@ grep '\[' /tmp/all_instructions | grep -vi '^lea' >/tmp/fullmemory_instructions echo '.intel_syntax noprefix' >/tmp/register_instructions egrep -vi '^(j|call|ret|push|pop)' /tmp/other_instructions | grep -vwi rsp | grep -vwi rip | sort | uniq >>/tmp/register_instructions echo '.intel_syntax noprefix' >/tmp/memory_instructions -sed -e 's/\[.*\]/MEMORY_CELL/' /tmp/fullmemory_instructions | grep -vwi rsp | grep -vwi rip | sort | uniq | sed -e 's/MEMORY_CELL/[rsp+32]/' >>/tmp/memory_instructions +sed -e 's/\[.*\]/MEMORY_CELL/' /tmp/fullmemory_instructions | grep -vwi rsp | grep -vwi rip | grep -vwi movs | sort | uniq | sed -e 's/MEMORY_CELL/[rsp+32]/' >>/tmp/memory_instructions # Now turn them into the syntax for the simulator OCaml input diff --git a/x86/Makefile b/x86/Makefile index 20e06e8a3..fe95fa0e7 100644 --- a/x86/Makefile +++ b/x86/Makefile @@ -258,6 +258,7 @@ BIGNUM_OBJ = curve25519/bignum_add_p25519.o \ mldsa/mldsa_pointwise_acc_l5.o \ mldsa/mldsa_pointwise_acc_l7.o \ mldsa/mldsa_reduce.o \ + mldsa/mldsa_rej_uniform.o \ mlkem/mlkem_basemul_k2.o \ mlkem/mlkem_basemul_k3.o \ mlkem/mlkem_basemul_k4.o \ diff --git a/x86/allowed_asm b/x86/allowed_asm index 20756bf79..456ed5a4c 100644 --- a/x86/allowed_asm +++ b/x86/allowed_asm @@ -127,6 +127,7 @@ : test$ : testq$ : vmovd$ +: vmovmskps$ : vmovq$ : vmovdqa$ : vmovdqu$ @@ -159,6 +160,7 @@ : vpmaddubsw$ : vpmaddwd$ : vpackuswb$ +: vpmovzxbd$ : vpmuldq$ : vpmulhrsw$ : vpmulhw$ @@ -184,6 +186,7 @@ : vpunpckhqdq$ : vpunpcklqdq$ : vpxor$ +: vzeroupper$ : xchg$ : xor$ : xorl$ diff --git a/x86/mldsa/mldsa_rej_uniform.S b/x86/mldsa/mldsa_rej_uniform.S new file mode 100644 index 000000000..4aefe4079 --- /dev/null +++ b/x86/mldsa/mldsa_rej_uniform.S @@ -0,0 +1,147 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +// ---------------------------------------------------------------------------- +// Uniform rejection sampling for ML-DSA +// Input buf[840] (uint8_t); output r[256] (int32_t); table[2048] (uint64_t) +// Returns: number of valid coefficients in r (at most 256) +// +// This function implements the rejection-sampling loop for ML-DSA, extracting +// 23-bit coefficients from packed 24-bit input bytes and keeping only those +// strictly less than q = 8380417. A main AVX2 loop processes 24 bytes (8 +// coefficients) per iteration using VPERMQ+VPSHUFB extraction, VPAND masking, +// VPSUBD+VMOVMSKPS rejection, and VPERMD+table compaction. A scalar tail +// handles any remaining bytes after the main loop exits. +// +// This implementation is derived from the public domain AVX2 Dilithium +// implementation from CRYSTALS-Dilithium optimized AVX2 implementation by +// Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé +// (https://github.com/pq-crystals/dilithium/tree/master/avx2) +// +// extern uint32_t mldsa_rej_uniform +// (int32_t r[static 256], +// const uint8_t buf[static 840], +// const uint64_t table[static 256]); +// +// Standard x86-64 ABI: RDI = r, RSI = buf, RDX = table +// Microsoft x64 ABI: RCX = r, RDX = buf, R8 = table +// ---------------------------------------------------------------------------- + +#include "_internal_s2n_bignum_x86.h" + + .intel_syntax noprefix + S2N_BN_SYM_VISIBILITY_DIRECTIVE(mldsa_rej_uniform) + S2N_BN_FUNCTION_TYPE_DIRECTIVE(mldsa_rej_uniform) + S2N_BN_SYM_PRIVACY_DIRECTIVE(mldsa_rej_uniform) + .text + +S2N_BN_SYMBOL(mldsa_rej_uniform): + + _CET_ENDBR + +#if WINDOWS_ABI + push rdi + push rsi + sub rsp, 160 + movdqu [rsp+0], xmm6 + movdqu [rsp+16], xmm7 + movdqu [rsp+32], xmm8 + movdqu [rsp+48], xmm9 + movdqu [rsp+64], xmm10 + movdqu [rsp+80], xmm11 + movdqu [rsp+96], xmm12 + movdqu [rsp+112], xmm13 + movdqu [rsp+128], xmm14 + movdqu [rsp+144], xmm15 + mov rdi, rcx + mov rsi, rdx + mov rdx, r8 +#endif + + // Shuffle mask: expand 24 bytes (8 x 3-byte coefficients) into + // 8 x 4-byte lanes (with a zero high byte in each). + mov r10, 0xff050403ff020100 + vmovq xmm0, r10 + mov r10, 0xff0b0a09ff080706 + vpinsrq xmm0, xmm0, r10, 0x1 + mov r10, 0xff090807ff060504 + vmovq xmm3, r10 + mov r10, 0xff0f0e0dff0c0b0a + vpinsrq xmm3, xmm3, r10, 0x1 + vinserti128 ymm0, ymm0, xmm3, 0x1 + + // Mask 0x7fffff in all 8 lanes (keep low 23 bits) + mov r8d, 0x7fffff + vmovd xmm1, r8d + vpbroadcastd ymm1, xmm1 + + // Threshold q = 0x7fe001 in all 8 lanes + mov r8d, 0x7fe001 + vmovd xmm2, r8d + vpbroadcastd ymm2, xmm2 + + // rax = accepted count, rcx = byte offset into buf + xor eax, eax + xor ecx, ecx + +Lmldsa_rej_uniform_loop: + // Exit to scalar tail if we have 248 or more accepted (next 8 might + // overflow) or if byte offset is past 808 (would read past buf+840-24). + cmp eax, 0xf8 + ja Lmldsa_rej_uniform_scalar + cmp ecx, 0x328 + ja Lmldsa_rej_uniform_scalar + + vmovdqu ymm3, YMMWORD PTR [rsi+rcx] + add ecx, 0x18 + vpermq ymm3, ymm3, 0x94 + vpshufb ymm3, ymm3, ymm0 + vpand ymm3, ymm3, ymm1 + vpsubd ymm4, ymm3, ymm2 + vmovmskps r8d, ymm4 + popcnt r9d, r8d + vmovq xmm4, QWORD PTR [rdx+8*r8] + vpmovzxbd ymm4, xmm4 + vpermd ymm3, ymm4, ymm3 + vmovdqu YMMWORD PTR [rdi+4*rax], ymm3 + add eax, r9d + jmp Lmldsa_rej_uniform_loop + +Lmldsa_rej_uniform_scalar: + cmp eax, 0x100 + jae Lmldsa_rej_uniform_done + cmp ecx, 0x345 + ja Lmldsa_rej_uniform_done + movzx r8d, WORD PTR [rsi+rcx] + movzx r9d, BYTE PTR [rsi+rcx+2] + shl r9d, 0x10 + or r8d, r9d + and r8d, 0x7fffff + add ecx, 0x3 + cmp r8d, 0x7fe001 + jae Lmldsa_rej_uniform_scalar + mov DWORD PTR [rdi+4*rax], r8d + add eax, 0x1 + jmp Lmldsa_rej_uniform_scalar + +Lmldsa_rej_uniform_done: +#if WINDOWS_ABI + movdqu xmm6, [rsp+0] + movdqu xmm7, [rsp+16] + movdqu xmm8, [rsp+32] + movdqu xmm9, [rsp+48] + movdqu xmm10, [rsp+64] + movdqu xmm11, [rsp+80] + movdqu xmm12, [rsp+96] + movdqu xmm13, [rsp+112] + movdqu xmm14, [rsp+128] + movdqu xmm15, [rsp+144] + add rsp, 160 + pop rsi + pop rdi +#endif + ret + +#if defined(__linux__) && defined(__ELF__) + .section .note.GNU-stack,"",%progbits +#endif diff --git a/x86/proofs/decode.ml b/x86/proofs/decode.ml index ae9a86dd5..c341c7f2a 100644 --- a/x86/proofs/decode.ml +++ b/x86/proofs/decode.ml @@ -695,6 +695,17 @@ let decode_aux = new_definition `!pfxs rex l. decode_aux pfxs rex l = match pfxs with | (T, Rep0, SG0) -> SOME (VPMULDQ (mmreg reg sz) (mmreg v sz) (simd_of_RM sz rm),l) | _ -> NONE) + | [0x31:8] -> if word_not v = (word 0b1111:4 word) then + let sz = vexL_size L in + (read_ModRM rex l >>= \((reg,rm),l). + let sop = if is_memop rm then + (if L then operand_of_RM Full_64 rm + else operand_of_RM Lower_32 rm) + else simd_of_RM Lower_128 rm in + match pfxs with + | (T, Rep0, SG0) -> SOME (VPMOVZXBD (mmreg reg sz) sop,l) + | _ -> NONE) + else NONE | [0x36:8] -> let sz = vexL_size L in (read_ModRM rex l >>= \((reg,rm),l). @@ -771,6 +782,15 @@ let decode_aux = new_definition `!pfxs rex l. decode_aux pfxs rex l = match pfxs with | (F, RepZ, SG0) -> SOME (VMOVSLDUP (mmreg reg sz) (simd_of_RM sz rm),l) | _ -> NONE) + | [0x50:8] -> if word_not v = (word 0b1111:4 word) then + (read_ModRM rex l >>= \((reg,rm),l). + let dest = %(gpr_adjust reg Lower_32) in + let sz = vexL_size L in + let src = simd_of_RM sz rm in + match pfxs with + | (F, Rep0, SG0) -> SOME (VMOVMSKPS dest src, l) + | _ -> NONE) + else NONE | [0x16:8] -> let sz = vexL_size L in (read_ModRM rex l >>= \((reg,rm),l). @@ -787,6 +807,12 @@ let decode_aux = new_definition `!pfxs rex l. decode_aux pfxs rex l = SOME (VMOVHPD dst src, l) | _ -> NONE)) else NONE + | [0x77:8] -> if word_not v = (word 0b1111:4 word) then + (if L then NONE else + match pfxs with + | (F, Rep0, SG0) -> SOME (VZEROUPPER, l) + | _ -> NONE) + else NONE | [0x6c:8] -> let sz = vexL_size L in (read_ModRM rex l >>= \((reg,rm),l). diff --git a/x86/proofs/instruction.ml b/x86/proofs/instruction.ml index a09d1fef8..850cc28fb 100644 --- a/x86/proofs/instruction.ml +++ b/x86/proofs/instruction.ml @@ -318,6 +318,7 @@ let instruction_INDUCTION,instruction_RECURSION = define_type | TEST operand operand | TZCNT operand operand | VMOVD operand operand + | VMOVMSKPS operand operand | VMOVQ operand operand | VMOVDQA operand operand | VMOVDQU operand operand @@ -351,6 +352,7 @@ let instruction_INDUCTION,instruction_RECURSION = define_type | VPBLENDVB operand operand operand operand | VPMADDUBSW operand operand operand | VPMADDWD operand operand operand + | VPMOVZXBD operand operand | VPMULDQ operand operand operand | VPMULHRSW operand operand operand | VPMULHW operand operand operand @@ -376,6 +378,7 @@ let instruction_INDUCTION,instruction_RECURSION = define_type | VPUNPCKHQDQ operand operand operand | VPUNPCKLQDQ operand operand operand | VPXOR operand operand operand + | VZEROUPPER | XCHG operand operand | XOR operand operand";; diff --git a/x86/proofs/mldsa_rej_uniform.ml b/x86/proofs/mldsa_rej_uniform.ml new file mode 100644 index 000000000..6682b6f31 --- /dev/null +++ b/x86/proofs/mldsa_rej_uniform.ml @@ -0,0 +1,7240 @@ +(* + * Copyright (c) The mldsa-native project authors + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 + *) + +(* ========================================================================= *) +(* ML-DSA Rejection uniform sampling (AVX2). *) +(* ========================================================================= *) + +needs "x86/proofs/base.ml";; + +needs "x86/proofs/mldsa_rej_uniform_table.ml";; + +(*** print_literal_from_elf "x86/mldsa/mldsa_rej_uniform.o";; + ***) + +let mldsa_rej_uniform_mc = define_assert_from_elf + "mldsa_rej_uniform_mc" "x86/mldsa/mldsa_rej_uniform.o" +(*** BYTECODE START ***) +[ + 0xf3; 0x0f; 0x1e; 0xfa; (* ENDBR64 *) + 0x49; 0xba; 0x00; 0x01; 0x02; 0xff; 0x03; 0x04; 0x05; 0xff; + (* MOV (% r10) (Imm64 (word 18376098269764911360)) *) + 0xc4; 0xc1; 0xf9; 0x6e; 0xc2; + (* VMOVQ (%_% xmm0) (% r10) *) + 0x49; 0xba; 0x06; 0x07; 0x08; 0xff; 0x09; 0x0a; 0x0b; 0xff; + (* MOV (% r10) (Imm64 (word 18377793742465140486)) *) + 0xc4; 0xc3; 0xf9; 0x22; 0xc2; 0x01; + (* VPINSRQ (%_% xmm0) (%_% xmm0) (% r10) (Imm8 (word 1)) *) + 0x49; 0xba; 0x04; 0x05; 0x06; 0xff; 0x07; 0x08; 0x09; 0xff; + (* MOV (% r10) (Imm64 (word 18377228584898397444)) *) + 0xc4; 0xc1; 0xf9; 0x6e; 0xda; + (* VMOVQ (%_% xmm3) (% r10) *) + 0x49; 0xba; 0x0a; 0x0b; 0x0c; 0xff; 0x0d; 0x0e; 0x0f; 0xff; + (* MOV (% r10) (Imm64 (word 18378924057598626570)) *) + 0xc4; 0xc3; 0xe1; 0x22; 0xda; 0x01; + (* VPINSRQ (%_% xmm3) (%_% xmm3) (% r10) (Imm8 (word 1)) *) + 0xc4; 0xe3; 0x7d; 0x38; 0xc3; 0x01; + (* VINSERTI128 (%_% ymm0) (%_% ymm0) (%_% xmm3) (Imm8 (word 1)) *) + 0x41; 0xb8; 0xff; 0xff; 0x7f; 0x00; + (* MOV (% r8d) (Imm32 (word 8388607)) *) + 0xc4; 0xc1; 0x79; 0x6e; 0xc8; + (* VMOVD (%_% xmm1) (% r8d) *) + 0xc4; 0xe2; 0x7d; 0x58; 0xc9; + (* VPBROADCASTD (%_% ymm1) (%_% xmm1) *) + 0x41; 0xb8; 0x01; 0xe0; 0x7f; 0x00; + (* MOV (% r8d) (Imm32 (word 8380417)) *) + 0xc4; 0xc1; 0x79; 0x6e; 0xd0; + (* VMOVD (%_% xmm2) (% r8d) *) + 0xc4; 0xe2; 0x7d; 0x58; 0xd2; + (* VPBROADCASTD (%_% ymm2) (%_% xmm2) *) + 0x31; 0xc0; (* XOR (% eax) (% eax) *) + 0x31; 0xc9; (* XOR (% ecx) (% ecx) *) + 0x3d; 0xf8; 0x00; 0x00; 0x00; + (* CMP (% eax) (Imm32 (word 248)) *) + 0x77; 0x46; (* JA (Imm8 (word 70)) *) + 0x81; 0xf9; 0x28; 0x03; 0x00; 0x00; + (* CMP (% ecx) (Imm32 (word 808)) *) + 0x77; 0x3e; (* JA (Imm8 (word 62)) *) + 0xc5; 0xfe; 0x6f; 0x1c; 0x0e; + (* VMOVDQU (%_% ymm3) (Memop Word256 (%%% (rsi,0,rcx))) *) + 0x83; 0xc1; 0x18; (* ADD (% ecx) (Imm8 (word 24)) *) + 0xc4; 0xe3; 0xfd; 0x00; 0xdb; 0x94; + (* VPERMQ (%_% ymm3) (%_% ymm3) (Imm8 (word 148)) *) + 0xc4; 0xe2; 0x65; 0x00; 0xd8; + (* VPSHUFB (%_% ymm3) (%_% ymm3) (%_% ymm0) *) + 0xc5; 0xe5; 0xdb; 0xd9; (* VPAND (%_% ymm3) (%_% ymm3) (%_% ymm1) *) + 0xc5; 0xe5; 0xfa; 0xe2; (* VPSUBD (%_% ymm4) (%_% ymm3) (%_% ymm2) *) + 0xc5; 0x7c; 0x50; 0xc4; (* VMOVMSKPS (% r8d) (%_% ymm4) *) + 0xf3; 0x45; 0x0f; 0xb8; 0xc8; + (* POPCNT (% r9d) (% r8d) *) + 0xc4; 0xa1; 0x7a; 0x7e; 0x24; 0xc2; + (* VMOVQ (%_% xmm4) (Memop Quadword (%%% (rdx,3,r8))) *) + 0xc4; 0xe2; 0x7d; 0x31; 0xe4; + (* VPMOVZXBD (%_% ymm4) (%_% xmm4) *) + 0xc4; 0xe2; 0x5d; 0x36; 0xdb; + (* VPERMD (%_% ymm3) (%_% ymm4) (%_% ymm3) *) + 0xc5; 0xfe; 0x7f; 0x1c; 0x87; + (* VMOVDQU (Memop Word256 (%%% (rdi,2,rax))) (%_% ymm3) *) + 0x44; 0x01; 0xc8; (* ADD (% eax) (% r9d) *) + 0xeb; 0xb3; (* JMP (Imm8 (word 179)) *) + 0x3d; 0x00; 0x01; 0x00; 0x00; + (* CMP (% eax) (Imm32 (word 256)) *) + 0x73; 0x36; (* JAE (Imm8 (word 54)) *) + 0x81; 0xf9; 0x45; 0x03; 0x00; 0x00; + (* CMP (% ecx) (Imm32 (word 837)) *) + 0x77; 0x2e; (* JA (Imm8 (word 46)) *) + 0x44; 0x0f; 0xb7; 0x04; 0x0e; + (* MOVZX (% r8d) (Memop Word (%%% (rsi,0,rcx))) *) + 0x44; 0x0f; 0xb6; 0x4c; 0x0e; 0x02; + (* MOVZX (% r9d) (Memop Byte (%%%% (rsi,0,rcx,&2))) *) + 0x41; 0xc1; 0xe1; 0x10; (* SHL (% r9d) (Imm8 (word 16)) *) + 0x45; 0x09; 0xc8; (* OR (% r8d) (% r9d) *) + 0x41; 0x81; 0xe0; 0xff; 0xff; 0x7f; 0x00; + (* AND (% r8d) (Imm32 (word 8388607)) *) + 0x83; 0xc1; 0x03; (* ADD (% ecx) (Imm8 (word 3)) *) + 0x41; 0x81; 0xf8; 0x01; 0xe0; 0x7f; 0x00; + (* CMP (% r8d) (Imm32 (word 8380417)) *) + 0x73; 0xcc; (* JAE (Imm8 (word 204)) *) + 0x44; 0x89; 0x04; 0x87; (* MOV (Memop Doubleword (%%% (rdi,2,rax))) (% r8d) *) + 0x83; 0xc0; 0x01; (* ADD (% eax) (Imm8 (word 1)) *) + 0xeb; 0xc3; (* JMP (Imm8 (word 195)) *) + 0xc3 (* RET *) +];; +(*** BYTECODE END ***) + +let mldsa_rej_uniform_tmc = + define_trimmed "mldsa_rej_uniform_tmc" mldsa_rej_uniform_mc;; + +let MLDSA_REJ_UNIFORM_EXEC = X86_MK_CORE_EXEC_RULE mldsa_rej_uniform_tmc;; + +(* ========================================================================= *) +(* Pre-helper lemmas (VPSUBD_SIGN_BIT_BOUNDED, SIGN_BIT_BITVAL). *) +(* ========================================================================= *) + +(* === Lemmas that helpers file depends on === *) + +let VPSUBD_SIGN_BIT_BOUNDED = prove + (`!x:int32. val x < 8388608 + ==> (bit 31 (word_sub x (word 8380417)) <=> val x < 8380417)`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[BIT_VAL; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[VAL_WORD_SUB; DIMINDEX_32; VAL_WORD] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_CASES_TAC `val(x:int32) < 8380417` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN + `(val(x:int32) + 4286586879) MOD 4294967296 = val x + 4286586879` + SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN + MATCH_MP_TAC(MESON[ODD; ARITH_RULE `ODD 1`] `n = 1 ==> ODD n`) THEN + MATCH_MP_TAC DIV_UNIQ THEN + EXISTS_TAC `val(x:int32) + 2139103231` THEN ASM_ARITH_TAC; + REWRITE_TAC[NOT_ODD] THEN + SUBGOAL_THEN + `(val(x:int32) + 4286586879) MOD 4294967296 = val x - 8380417` + SUBST1_TAC THENL + [SUBGOAL_THEN + `val(x:int32) + 4286586879 = (val x - 8380417) + 1 * 4294967296` + SUBST1_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[MOD_MULT_ADD] THEN MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; + ALL_TAC] THEN + SIMP_TAC[DIV_LT; EVEN] THEN ASM_ARITH_TAC]);; + +let SIGN_BIT_BITVAL = prove + (`!x0:int32. val x0 < 8388608 + ==> bitval(bit 31 (word_sub x0 (word 8380417):int32)) = bitval(val x0 < 8380417)`, + GEN_TAC THEN DISCH_TAC THEN AP_TERM_TAC THEN + REWRITE_TAC[BIT_VAL; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[VAL_WORD_SUB; DIMINDEX_32; VAL_WORD] THEN CONV_TAC NUM_REDUCE_CONV THEN + ASM_CASES_TAC `val(x0:int32) < 8380417` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `(val(x0:int32) + 4286586879) MOD 4294967296 = val x0 + 4286586879` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN + MATCH_MP_TAC(MESON[ODD; ARITH_RULE `ODD 1`] `n = 1 ==> ODD n`) THEN + MATCH_MP_TAC DIV_UNIQ THEN EXISTS_TAC `val(x0:int32) + 2139103231` THEN ASM_ARITH_TAC; + SUBGOAL_THEN `(val(x0:int32) + 4286586879) MOD 4294967296 = val x0 - 8380417` SUBST1_TAC THENL + [SUBGOAL_THEN `val(x0:int32) + 4286586879 = (val x0 - 8380417) + 1 * 4294967296` SUBST1_TAC THENL + [ASM_ARITH_TAC; REWRITE_TAC[MOD_MULT_ADD] THEN MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC]; + REWRITE_TAC[NOT_ODD] THEN SIMP_TAC[DIV_LT; EVEN] THEN ASM_ARITH_TAC]]);; + +(* ========================================================================= *) +(* Helper lemmas. *) +(* ========================================================================= *) + +(* Helper lemmas for mldsa_rej_uniform proof - VMOVMSKPS+POPCNT chain *) + +(* word_popcount is preserved through word_zx *) +let WORD_POPCOUNT_WORD_ZX = prove + (`!(w:N word). dimindex(:N) <= dimindex(:M) ==> word_popcount(word_zx w:M word) = word_popcount w`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[word_popcount] THEN AP_TERM_TAC THEN + REWRITE_TAC[EXTENSION; IN_ELIM_THM; bits_of_word; BIT_WORD_ZX] THEN + X_GEN_TAC `j:num` THEN EQ_TAC THEN + REPEAT STRIP_TAC THEN ASM_REWRITE_TAC[] THEN + ASM_MESON_TAC[BIT_TRIVIAL; NOT_LT; LTE_TRANS]);; + +(* word_of_bits VMOVMSKPS pattern = sum of bitvals *) +let VMOVMSKPS_BYTE_EQ = prove + (`!x:int256. word_of_bits(\i. i < 8 /\ bit(32*i+31) x):byte = + word(bitval(bit 31 x) + 2 * bitval(bit 63 x) + 4 * bitval(bit 95 x) + + 8 * bitval(bit 127 x) + 16 * bitval(bit 159 x) + 32 * bitval(bit 191 x) + + 64 * bitval(bit 223 x) + 128 * bitval(bit 255 x))`, + GEN_TAC THEN + REWRITE_TAC[WORD_OF_BITS_AS_WORD_ALT; DIMINDEX_8] THEN + CONV_TAC NUM_REDUCE_CONV THEN AP_TERM_TAC THEN + CONV_TAC(LAND_CONV EXPAND_NSUM_CONV) THEN + REWRITE_TAC[IN] THEN CONV_TAC(DEPTH_CONV BETA_CONV) THEN + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC);; + +(* bit(32*k+31)(x:int256) = bit 31(word_subword x (32*k,32):int32) *) +let BIT_SUBWORD_256 = prove + ((rand o concl o (EXPAND_CASES_CONV THENC NUM_REDUCE_CONV)) + `!i. i < 8 ==> + !x:int256. bit(32*i+31) x = bit 31 (word_subword x (32*i,32):int32)`, + CONV_TAC WORD_BLAST);; + +(* Combined: word_popcount of word_of_bits = word_popcount of bitval sum *) +let VMOVMSKPS_POPCOUNT_EQ = prove + (`!x:int256. + word_popcount(word_of_bits(\i. i < 8 /\ bit(32*i+31) x):byte) = + word_popcount(word( + bitval(bit 31 (word_subword x (0,32):int32)) + + 2 * bitval(bit 31 (word_subword x (32,32):int32)) + + 4 * bitval(bit 31 (word_subword x (64,32):int32)) + + 8 * bitval(bit 31 (word_subword x (96,32):int32)) + + 16 * bitval(bit 31 (word_subword x (128,32):int32)) + + 32 * bitval(bit 31 (word_subword x (160,32):int32)) + + 64 * bitval(bit 31 (word_subword x (192,32):int32)) + + 128 * bitval(bit 31 (word_subword x (224,32):int32))):byte)`, + GEN_TAC THEN AP_TERM_TAC THEN + REWRITE_TAC[VMOVMSKPS_BYTE_EQ; BIT_SUBWORD_256]);; + +(* Extract bit 31 from each lane of nested word_join of int32's *) +let BIT_NESTED_JOIN_8 = REWRITE_RULE[LET_DEF; LET_END_DEF] (prove + (`!(a0:int32) (a1:int32) (a2:int32) (a3:int32) (a4:int32) (a5:int32) (a6:int32) (a7:int32). + let x:int256 = word_join + (word_join (word_join a7 a6:int64) (word_join a5 a4:int64):int128) + (word_join (word_join a3 a2:int64) (word_join a1 a0:int64):int128) in + bit 31 (word_subword x (0,32):int32) = bit 31 a0 /\ + bit 31 (word_subword x (32,32):int32) = bit 31 a1 /\ + bit 31 (word_subword x (64,32):int32) = bit 31 a2 /\ + bit 31 (word_subword x (96,32):int32) = bit 31 a3 /\ + bit 31 (word_subword x (128,32):int32) = bit 31 a4 /\ + bit 31 (word_subword x (160,32):int32) = bit 31 a5 /\ + bit 31 (word_subword x (192,32):int32) = bit 31 a6 /\ + bit 31 (word_subword x (224,32):int32) = bit 31 a7`, + REPEAT GEN_TAC THEN CONV_TAC let_CONV THEN + REWRITE_TAC[BIT_WORD_SUBWORD; BIT_WORD_JOIN; + DIMINDEX_32; DIMINDEX_64; DIMINDEX_128; DIMINDEX_256] THEN + CONV_TAC NUM_REDUCE_CONV));; + +(* 3-byte word_join with zero high byte = word_zx of 24-bit join *) +let BYTE_JOIN_ZX = prove + (`!b0 b1 b2:byte. + word_join (word_join (word 0:byte) (b2:byte):int16) + (word_join (b1:byte) (b0:byte):int16):int32 = + word_zx(word_join (word_join (b2:byte) (b1:byte):int16) (b0:byte):24 word):int32`, + REPEAT GEN_TAC THEN CONV_TAC WORD_BLAST);; + +(* word_and with 0x7FFFFF mask on word_zx of 24-bit = word_zx of 23-bit subword *) +let BYTE_JOIN_SUBWORD_ZX = prove + (`!b0 b1 b2:byte. + word_and (word_zx(word_join (word_join (b2:byte) (b1:byte):int16) (b0:byte):24 word):int32) + (word 8388607:int32) = + word_zx(word_subword (word_join (word_join (b2:byte) (b1:byte):int16) (b0:byte):24 word) (0,23):23 word):int32`, + REPEAT GEN_TAC THEN CONV_TAC WORD_BLAST);; + +(* Little-endian 3-byte reconstruction at num level *) +let BYTES3_NUM = prove + (`!n. n MOD 256 + 256 * (n DIV 256) MOD 256 + 65536 * (n DIV 65536) MOD 256 = n MOD 16777216`, + GEN_TAC THEN + SUBGOAL_THEN `16777216 = 65536 * 256` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `65536 = 256 * 256` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[GSYM DIV_DIV; MOD_MULT_MOD] THEN + REWRITE_TAC[ARITH_RULE `256 * 256 * 256 = 256 * (256 * 256)`] THEN + REWRITE_TAC[MOD_MULT_MOD] THEN + MP_TAC(SPEC `256` (SPEC `n:num` DIVISION)) THEN + MP_TAC(SPEC `256` (SPEC `n DIV 256` DIVISION)) THEN + REWRITE_TAC[ARITH_RULE `~(256 = 0)`] THEN ARITH_TAC);; + +(* val of 3-byte word_join *) +let BYTE_JOIN_VAL = prove + (`!b0 b1 b2:byte. + val(word_join (word_join (b2:byte) (b1:byte):int16) (b0:byte) : 24 word) = + val b0 + 256 * val b1 + 65536 * val b2`, + REPEAT GEN_TAC THEN + REWRITE_TAC[VAL_WORD_JOIN; DIMINDEX_8; DIMINDEX_16] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(ISPEC `b0:byte` VAL_BOUND) THEN + MP_TAC(ISPEC `b1:byte` VAL_BOUND) THEN + MP_TAC(ISPEC `b2:byte` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_8] THEN CONV_TAC NUM_REDUCE_CONV THEN + REPEAT STRIP_TAC THEN + SUBGOAL_THEN `256 * val(b2:byte) + val(b1:byte) < 65536` + (fun th -> SIMP_TAC[th; MOD_LT]) THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `256 * (256 * val(b2:byte) + val(b1:byte)) + val(b0:byte) < 16777216` + (fun th -> SIMP_TAC[th; MOD_LT]) THENL [ASM_ARITH_TAC; ARITH_TAC]);; + +(* val of byte_join from word n : int256 = n DIV 2^ofs MOD 2^24 *) +let BYTE_JOIN_VAL_WORD = prove + (`!n ofs. + val(word_join (word_join (word_subword (word n:int256) (ofs+16,8):byte) + (word_subword (word n:int256) (ofs+8,8):byte):int16) + (word_subword (word n:int256) (ofs,8):byte) : 24 word) = + (n MOD 2 EXP 256) DIV 2 EXP ofs MOD 2 EXP 24`, + REPEAT GEN_TAC THEN + REWRITE_TAC[BYTE_JOIN_VAL; VAL_WORD_SUBWORD; VAL_WORD; DIMINDEX_8] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN + CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[EXP_ADD; GSYM DIV_DIV] THEN CONV_TAC NUM_REDUCE_CONV THEN + SPEC_TAC(`(n MOD 2 EXP 256) DIV 2 EXP ofs`, `m:num`) THEN + REWRITE_TAC[BYTES3_NUM]);; + +(* Full coefficient lemma: byte_join + 23-bit mask from word n = word(n DIV 2^ofs MOD 2^23) *) +let COEFF_BYTE_JOIN_WORD = prove + (`!n ofs. + word_zx(word_subword + (word_join (word_join (word_subword (word n:int256) (ofs+16,8):byte) + (word_subword (word n:int256) (ofs+8,8):byte):int16) + (word_subword (word n:int256) (ofs,8):byte) : 24 word) + (0,23) : 23 word) : int32 = + word((n MOD 2 EXP 256) DIV 2 EXP ofs MOD 2 EXP 23)`, + REPEAT GEN_TAC THEN + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_SUBWORD; VAL_WORD; + DIMINDEX_8; DIMINDEX_32] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN + CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[BYTE_JOIN_VAL_WORD; DIV_1] THEN + ONCE_REWRITE_TAC[GSYM(NUM_REDUCE_CONV `2 EXP 24`)] THEN + ONCE_REWRITE_TAC[GSYM(NUM_REDUCE_CONV `2 EXP 23`)] THEN + ONCE_REWRITE_TAC[GSYM(NUM_REDUCE_CONV `2 EXP 32`)] THEN + REWRITE_TAC[MOD_MOD_EXP_MIN] THEN CONV_TAC NUM_REDUCE_CONV);; + +(* Reduce MOD 2^256 to MOD 2^192 in DIV/MOD extraction context *) +let MOD_256_192 = prove + (`!n k. k + 23 <= 192 ==> + (n MOD 2 EXP 256) DIV (2 EXP k) MOD (2 EXP 23) = + (n MOD 2 EXP 192) DIV (2 EXP k) MOD (2 EXP 23)`, + REPEAT STRIP_TAC THEN + REWRITE_TAC[DIV_MOD; GSYM EXP_ADD; MOD_MOD_EXP_MIN] THEN + AP_THM_TAC THEN AP_TERM_TAC THEN AP_TERM_TAC THEN AP_TERM_TAC THEN + ASM_ARITH_TAC);; + +(* word_popcount is preserved through word_zx *) +(* val(word n : 24 word) MOD 2^23 = n MOD 2^23 — avoids MOD_MOD_EXP_MIN loop *) +let VAL_WORD_24_MOD_23 = prove + (`!n. val(word n : 24 word) MOD 2 EXP 23 = n MOD 2 EXP 23`, + GEN_TAC THEN REWRITE_TAC[VAL_WORD] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN + REWRITE_TAC[MOD_MOD_EXP_MIN] THEN + CONV_TAC NUM_REDUCE_CONV);; + +(* MAP of REJ_SAMPLE coefficient extraction = concrete list *) +let MAP_REJ_COEFFS = prove + (`!l:(24 word)list. LENGTH l = 8 ==> + MAP (\x:24 word. word(val x MOD 2 EXP 23):int32) l = + [word(num_of_wordlist l MOD 2 EXP 23); + word(num_of_wordlist l DIV 2 EXP 24 MOD 2 EXP 23); + word(num_of_wordlist l DIV 2 EXP 48 MOD 2 EXP 23); + word(num_of_wordlist l DIV 2 EXP 72 MOD 2 EXP 23); + word(num_of_wordlist l DIV 2 EXP 96 MOD 2 EXP 23); + word(num_of_wordlist l DIV 2 EXP 120 MOD 2 EXP 23); + word(num_of_wordlist l DIV 2 EXP 144 MOD 2 EXP 23); + word(num_of_wordlist l DIV 2 EXP 168 MOD 2 EXP 23)]`, + GEN_TAC THEN DISCH_TAC THEN REWRITE_TAC[LIST_EQ] THEN + CONV_TAC(ONCE_DEPTH_CONV LENGTH_CONV) THEN + REWRITE_TAC[LENGTH_MAP] THEN ASM_REWRITE_TAC[] THEN + ASM_SIMP_TAC[EL_MAP; EL_NUM_OF_WORDLIST; + ARITH_RULE `LENGTH l = 8 ==> (n < 8 ==> n < LENGTH l)`] THEN + REWRITE_TAC[VAL_WORD_24_MOD_23] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN + CONV_TAC EXPAND_CASES_CONV THEN REPEAT CONJ_TAC THEN + CONV_TAC(ONCE_DEPTH_CONV EL_CONV) THEN + CONV_TAC(ONCE_DEPTH_CONV NUM_REDUCE_CONV) THEN + REWRITE_TAC[DIV_1]);; + +(* NOTE: REJ_SAMPLE_COEFFS was moved to the main proof file because + it depends on REJ_SAMPLE which is defined there, not in the helpers. *) + +(* Memory bytes split: read(bytes(a, m+n)) = read(bytes(a,m)) + 2^(8m) * read(bytes(a+m, n)) *) +let MEMORY_BYTES_SPLIT = prove + (`!a m n s. read (memory :> bytes (a:int64, m + n)) s = + read (memory :> bytes (a, m)) s + + 2 EXP (8 * m) * read (memory :> bytes (word_add a (word m), n)) s`, + REWRITE_TAC[READ_COMPONENT_COMPOSE; READ_BYTES_COMBINE]);; + +(* CMP_MASK_CORRECT: VMOVMSKPS(VPSUBD(coeffs, Q)) = bitval sum of (val c_k < Q). + Connects the comparison mask byte to the FILTER predicate. *) +let CMP_MASK_CORRECT = prove( + `!c0 c1 c2 c3 c4 c5 c6 c7:int32. + val c0 < 8388608 /\ val c1 < 8388608 /\ val c2 < 8388608 /\ + val c3 < 8388608 /\ val c4 < 8388608 /\ val c5 < 8388608 /\ + val c6 < 8388608 /\ val c7 < 8388608 ==> + val(word_zx(word_zx(word_of_bits + (\i. i < 8 /\ + bit (32 * i + 31) + (word_join + (word_join + (word_join + (word_sub c7 (word 8380417):int32) + (word_sub c6 (word 8380417):int32) : (64)word) + (word_join + (word_sub c5 (word 8380417):int32) + (word_sub c4 (word 8380417):int32) : (64)word) : (128)word) + (word_join + (word_join + (word_sub c3 (word 8380417):int32) + (word_sub c2 (word 8380417):int32) : (64)word) + (word_join + (word_sub c1 (word 8380417):int32) + (word_sub c0 (word 8380417):int32) : (64)word) : (128)word) + :int256)) :byte) :int32) :int64) = + bitval(val c0 < 8380417) + 2 * bitval(val c1 < 8380417) + + 4 * bitval(val c2 < 8380417) + 8 * bitval(val c3 < 8380417) + + 16 * bitval(val c4 < 8380417) + 32 * bitval(val c5 < 8380417) + + 64 * bitval(val c6 < 8380417) + 128 * bitval(val c7 < 8380417)`, + REPEAT STRIP_TAC THEN + REWRITE_TAC[VMOVMSKPS_BYTE_EQ; BIT_SUBWORD_256] THEN + CONV_TAC(TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) THEN + ASM_SIMP_TAC[VPSUBD_SIGN_BIT_BOUNDED; SIGN_BIT_BITVAL] THEN + REWRITE_TAC[bitval] THEN + REPEAT(COND_CASES_TAC THEN ASM_REWRITE_TAC[]) THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + CONV_TAC NUM_REDUCE_CONV);; + +(* Pre-compute the 256 table entry values for VPERMD brute force. + Each entry is an int64 value: 8 bytes from the table at offset 8*mask. *) +let TABLE_ENTRY_VALS = + let table_expanded = + (REWRITE_CONV[mldsa_rej_uniform_table; num_of_wordlist; DIMINDEX_8] THENC + DEPTH_CONV WORD_NUM_RED_CONV THENC NUM_REDUCE_CONV) + `num_of_wordlist mldsa_rej_uniform_table` in + let table_num = rhs(concl table_expanded) in + let entries = Array.init 256 (fun m -> + let tm = mk_comb(mk_comb(`(MOD)`, + mk_comb(mk_comb(`(DIV)`, table_num), + mk_comb(mk_comb(`(EXP)`, `2`), mk_numeral(Num.num_of_int(64*m))))), + mk_comb(mk_comb(`(EXP)`, `2`), `64`)) in + let th = NUM_REDUCE_CONV tm in + let rhs_val = rhs(concl th) in + (* Prove: (num_of_wordlist table DIV 2^(64*m)) MOD 2^64 = entry_m *) + let lhs_tm = mk_comb(mk_comb(`(MOD)`, + mk_comb(mk_comb(`(DIV)`, + `num_of_wordlist mldsa_rej_uniform_table`), + mk_comb(mk_comb(`(EXP)`, `2`), mk_numeral(Num.num_of_int(64*m))))), + mk_comb(mk_comb(`(EXP)`, `2`), `64`)) in + let eq = mk_eq(lhs_tm, rhs_val) in + EQT_ELIM((REWRITE_CONV[table_expanded] THENC NUM_REDUCE_CONV) eq)) in + entries;; + +(* TABLE_ENTRY_FROM_MEMORY: connect bytes64 memory read at table+8k to + (table_num DIV 2^(64k)) MOD 2^64 via bigdigit/bignum_from_memory *) +let TABLE_ENTRY_FROM_MEMORY = prove( + `!table (s:x86state) k. + read(memory :> bytes(table:int64, 2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list) /\ + k < 256 + ==> val(read(memory :> bytes64(word_add table (word(8 * k)))) s :int64) = + (num_of_wordlist mldsa_rej_uniform_table DIV 2 EXP (64 * k)) MOD 2 EXP 64`, + REPEAT STRIP_TAC THEN + MP_TAC(ISPECL [`256`; `table:int64`; `s:x86state`; `k:num`] + BIGDIGIT_BIGNUM_FROM_MEMORY) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN(SUBST1_TAC o SYM) THEN + REWRITE_TAC[bigdigit] THEN AP_THM_TAC THEN AP_TERM_TAC THEN AP_THM_TAC THEN AP_TERM_TAC THEN + REWRITE_TAC[BIGNUM_FROM_MEMORY_BYTES] THEN CONV_TAC NUM_REDUCE_CONV THEN + ASM_REWRITE_TAC[]);; + +(* TABLE_NUM_THM: expand mldsa_rej_uniform_table to a numeral for table lookup *) +let TABLE_NUM_THM = + (REWRITE_CONV[mldsa_rej_uniform_table; num_of_wordlist; DIMINDEX_8] THENC + DEPTH_CONV WORD_NUM_RED_CONV THENC NUM_REDUCE_CONV) + `num_of_wordlist mldsa_rej_uniform_table`;; + +(* VAL_WORD_GALOIS_64: derive x = word n from val x = n *) +let VAL_WORD_GALOIS_64 = prove( + `!x:int64 n. val x = n /\ n < 18446744073709551616 ==> x = word n`, + REPEAT STRIP_TAC THEN + SUBGOAL_THEN `x:int64 = word(val x)` SUBST1_TAC THENL + [REWRITE_TAC[WORD_VAL]; ASM_REWRITE_TAC[]]);; + +(* VAL_WORD_JOIN8: flatten nested val(word_join^8) to sum of 2^(32*k) * val(ck) *) +let VAL_WORD_JOIN8 = prove( + `!(c0:int32)(c1:int32)(c2:int32)(c3:int32)(c4:int32)(c5:int32)(c6:int32)(c7:int32). + val(word_join + (word_join (word_join c7 c6:(64)word) (word_join c5 c4:(64)word):(128)word) + (word_join (word_join c3 c2:(64)word) (word_join c1 c0:(64)word):(128)word) + :int256) = + val c0 + 2 EXP 32 * val c1 + 2 EXP 64 * val c2 + 2 EXP 96 * val c3 + + 2 EXP 128 * val c4 + 2 EXP 160 * val c5 + 2 EXP 192 * val c6 + 2 EXP 224 * val c7`, + REPEAT GEN_TAC THEN + REWRITE_TAC[VAL_WORD_JOIN; DIMINDEX_32; DIMINDEX_64; DIMINDEX_128] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN CONV_TAC NUM_REDUCE_CONV THEN + MAP_EVERY (fun c -> MP_TAC(ISPEC c VAL_BOUND) THEN REWRITE_TAC[DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV) [`c0:int32`;`c1:int32`;`c2:int32`;`c3:int32`; + `c4:int32`;`c5:int32`;`c6:int32`;`c7:int32`] THEN + REPEAT STRIP_TAC THEN + SUBGOAL_THEN `4294967296 * val(c1:int32) + val(c0:int32) < 18446744073709551616` + (fun th -> REWRITE_TAC[MATCH_MP MOD_LT th]) THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `4294967296 * val(c3:int32) + val(c2:int32) < 18446744073709551616` + (fun th -> REWRITE_TAC[MATCH_MP MOD_LT th]) THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `4294967296 * val(c5:int32) + val(c4:int32) < 18446744073709551616` + (fun th -> REWRITE_TAC[MATCH_MP MOD_LT th]) THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `4294967296 * val(c7:int32) + val(c6:int32) < 18446744073709551616` + (fun th -> REWRITE_TAC[MATCH_MP MOD_LT th]) THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `18446744073709551616 * (4294967296 * val(c3:int32) + val(c2:int32)) + + (4294967296 * val(c1:int32) + val(c0:int32)) < + 340282366920938463463374607431768211456` + (fun th -> REWRITE_TAC[MATCH_MP MOD_LT th]) THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `18446744073709551616 * (4294967296 * val(c7:int32) + val(c6:int32)) + + (4294967296 * val(c5:int32) + val(c4:int32)) < + 340282366920938463463374607431768211456` + (fun th -> REWRITE_TAC[MATCH_MP MOD_LT th]) THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `340282366920938463463374607431768211456 * + (18446744073709551616 * (4294967296 * val(c7:int32) + val(c6:int32)) + + (4294967296 * val(c5:int32) + val(c4:int32))) + + (18446744073709551616 * (4294967296 * val(c3:int32) + val(c2:int32)) + + (4294967296 * val(c1:int32) + val(c0:int32))) < + 115792089237316195423570985008687907853269984665640564039457584007913129639936` + (fun th -> REWRITE_TAC[MATCH_MP MOD_LT th]) THENL [ASM_ARITH_TAC; ALL_TAC] THEN + ARITH_TAC);; + +(* MOD_BASE_REWRITES: convert numeral MOD bases to symbolic 2 EXP k *) +let MOD_BASE_REWRITES = [ + GSYM(NUM_REDUCE_CONV `2 EXP 32`); + GSYM(NUM_REDUCE_CONV `2 EXP 64`); + GSYM(NUM_REDUCE_CONV `2 EXP 96`); + GSYM(NUM_REDUCE_CONV `2 EXP 128`); + GSYM(NUM_REDUCE_CONV `2 EXP 160`); + GSYM(NUM_REDUCE_CONV `2 EXP 192`); + GSYM(NUM_REDUCE_CONV `2 EXP 224`); + GSYM(NUM_REDUCE_CONV `2 EXP 256`)];; + +(* VAL_BOUND_256: val(x:int256) < 2 EXP 256 *) +let VAL_BOUND_256 = prove( + `!x:int256. val x < 2 EXP 256`, + GEN_TAC THEN MP_TAC(ISPEC `x:int256` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_256]);; + +(* Factor rules for MOD stripping: rewrite 2^k * x to (2^(k-m) * x) * 2^m *) +let vpermd_factor_for m = List.filter_map (fun k -> + if k >= m && k <= 224 then + let two_exp j = mk_comb(mk_comb(`EXP`,`2`),mk_numeral(Num.num_of_int j)) in + let mul_tm = `( * )` in + let mk_mul a b = mk_comb(mk_comb(mul_tm,a),b) in + Some(ARITH_RULE(mk_eq( + mk_mul (two_exp k) `x:num`, + mk_mul (mk_mul (two_exp (k-m)) `x:num`) (two_exp m)))) + else None) + [32;64;96;128;160;192;224];; + +let VPERMD_FACTOR_RULES = List.map (fun m -> (m, vpermd_factor_for m)) + [32;64;96;128;160;192;224];; + +(* VPERMD_FACTOR_STRIP_TAC: detect MOD base, apply factor rules, strip, MOD_LT *) +let VPERMD_FACTOR_STRIP_TAC : tactic = fun (asl, w) -> + let base = try + let mod_term = rand(lhand w) in + Num.int_of_num(dest_numeral(rand mod_term)) + with _ -> 0 in + let gk = try List.assoc base VPERMD_FACTOR_RULES with Not_found -> [] in + (if gk = [] then ALL_TAC + else + REWRITE_TAC gk THEN + TRY(ONCE_REWRITE_TAC[ARITH_RULE `a+b+c+d+e+f+g+h = (a+b+c+d+e+f+g)+h`] THEN REWRITE_TAC[MOD_MULT_ADD]) THEN + TRY(ONCE_REWRITE_TAC[ARITH_RULE `a+b+c+d+e+f+g = (a+b+c+d+e+f)+g`] THEN REWRITE_TAC[MOD_MULT_ADD]) THEN + TRY(ONCE_REWRITE_TAC[ARITH_RULE `a+b+c+d+e+f = (a+b+c+d+e)+f`] THEN REWRITE_TAC[MOD_MULT_ADD]) THEN + TRY(ONCE_REWRITE_TAC[ARITH_RULE `a+b+c+d+e = (a+b+c+d)+e`] THEN REWRITE_TAC[MOD_MULT_ADD]) THEN + TRY(ONCE_REWRITE_TAC[ARITH_RULE `a+b+c+d = (a+b+c)+d`] THEN REWRITE_TAC[MOD_MULT_ADD]) THEN + TRY(ONCE_REWRITE_TAC[ARITH_RULE `a+b+c = (a+b)+c`] THEN REWRITE_TAC[MOD_MULT_ADD]) THEN + TRY(REWRITE_TAC[MOD_MULT_ADD]) THEN + TRY(MATCH_MP_TAC MOD_LT THEN + REWRITE_TAC[MULT_CLAUSES] THEN + RULE_ASSUM_TAC(REWRITE_RULE[DIMINDEX_32]) THEN ASM_ARITH_TAC)) + (asl, w);; + +(* VPERMD_TABLE_CORRECT: 256-case brute force proof that VPERMD with the mldsa + table correctly compacts coefficients matching FILTER. + Preconditions: 8 coefficients bounded < 2^23, table entry for the mask. + Conclusion: val(VPERMD result) MOD 2^(32*popcount) = num_of_wordlist(FILTER ...) *) +let VPERMD_TABLE_CORRECT = time prove( + `!(c0:int32) (c1:int32) (c2:int32) (c3:int32) (c4:int32) (c5:int32) (c6:int32) (c7:int32) (te:int64). + val c0 < 8388608 /\ val c1 < 8388608 /\ val c2 < 8388608 /\ val c3 < 8388608 /\ + val c4 < 8388608 /\ val c5 < 8388608 /\ val c6 < 8388608 /\ val c7 < 8388608 /\ + val te = (num_of_wordlist mldsa_rej_uniform_table DIV + 2 EXP (64 * (bitval(val c0 < 8380417) + 2 * bitval(val c1 < 8380417) + + 4 * bitval(val c2 < 8380417) + 8 * bitval(val c3 < 8380417) + + 16 * bitval(val c4 < 8380417) + 32 * bitval(val c5 < 8380417) + + 64 * bitval(val c6 < 8380417) + 128 * bitval(val c7 < 8380417)))) + MOD 2 EXP 64 + ==> + let coeffs = word_join + (word_join (word_join c7 c6 :(64)word) (word_join c5 c4 :(64)word) :(128)word) + (word_join (word_join c3 c2 :(64)word) (word_join c1 c0 :(64)word) :(128)word) :int256 in + let ix = word_join + (word_join (word_join (word_zx(word_subword te (56,8):byte):int32) + (word_zx(word_subword te (48,8):byte):int32) :(64)word) + (word_join (word_zx(word_subword te (40,8):byte):int32) + (word_zx(word_subword te (32,8):byte):int32) :(64)word) :(128)word) + (word_join (word_join (word_zx(word_subword te (24,8):byte):int32) + (word_zx(word_subword te (16,8):byte):int32) :(64)word) + (word_join (word_zx(word_subword te (8,8):byte):int32) + (word_zx(word_subword te (0,8):byte):int32) :(64)word) :(128)word) :int256 in + let res = word_join + (word_join (word_join (word_subword coeffs (32 * val(word_subword ix (224,3):(3)word), 32) :int32) + (word_subword coeffs (32 * val(word_subword ix (192,3):(3)word), 32) :int32) :(64)word) + (word_join (word_subword coeffs (32 * val(word_subword ix (160,3):(3)word), 32) :int32) + (word_subword coeffs (32 * val(word_subword ix (128,3):(3)word), 32) :int32) :(64)word) :(128)word) + (word_join (word_join (word_subword coeffs (32 * val(word_subword ix (96,3):(3)word), 32) :int32) + (word_subword coeffs (32 * val(word_subword ix (64,3):(3)word), 32) :int32) :(64)word) + (word_join (word_subword coeffs (32 * val(word_subword ix (32,3):(3)word), 32) :int32) + (word_subword coeffs (32 * val(word_subword ix (0,3):(3)word), 32) :int32) :(64)word) :(128)word) :int256 in + val res MOD 2 EXP (32 * (bitval(val c0 < 8380417) + bitval(val c1 < 8380417) + + bitval(val c2 < 8380417) + bitval(val c3 < 8380417) + + bitval(val c4 < 8380417) + bitval(val c5 < 8380417) + + bitval(val c6 < 8380417) + bitval(val c7 < 8380417))) = + num_of_wordlist(FILTER (\c:int32. val c < 8380417) [c0;c1;c2;c3;c4;c5;c6;c7])`, + REPEAT GEN_TAC THEN STRIP_TAC THEN CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + FIRST_X_ASSUM MP_TAC THEN + MAP_EVERY ASM_CASES_TAC + [`val(c0:int32) < 8380417`; `val(c1:int32) < 8380417`; + `val(c2:int32) < 8380417`; `val(c3:int32) < 8380417`; + `val(c4:int32) < 8380417`; `val(c5:int32) < 8380417`; + `val(c6:int32) < 8380417`; `val(c7:int32) < 8380417`] THEN + ASM_REWRITE_TAC[bitval] THEN + CONV_TAC(LAND_CONV(RAND_CONV(REWRITE_CONV[TABLE_NUM_THM] THENC NUM_REDUCE_CONV))) THEN + DISCH_THEN(fun th -> + let n = rhs(concl th) in + SUBST_ALL_TAC(MATCH_MP VAL_WORD_GALOIS_64 + (CONJ th (EQT_ELIM(NUM_REDUCE_CONV(mk_comb(mk_comb(`(<)`,n), `18446744073709551616`))))))) THEN + CONV_TAC(DEPTH_CONV(WORD_NUM_RED_CONV ORELSEC WORD_SIMPLE_SUBWORD_CONV)) THEN + CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[FILTER] THEN ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + TRY(REWRITE_TAC[MOD_1; num_of_wordlist] THEN REFL_TAC) THEN + REWRITE_TAC MOD_BASE_REWRITES THEN + TRY(SIMP_TAC[MOD_LT; VAL_BOUND_256]) THEN + REWRITE_TAC[VAL_WORD_JOIN8] THEN + CONV_TAC(RAND_CONV(REWRITE_CONV[num_of_wordlist; ADD_0; DIMINDEX_32; + LEFT_ADD_DISTRIB; MULT_CLAUSES; MULT_ASSOC; GSYM(SPEC `2` EXP_ADD)] THENC + DEPTH_CONV NUM_ADD_CONV)) THEN + TRY REFL_TAC THEN + VPERMD_FACTOR_STRIP_TAC);; + +(* RESOLVE_TABLE_READ_TAC: resolve read(bytes64(word_add table (word K))) terms + in the goal using TABLE_ENTRY_FROM_MEMORY + the memory-table hypothesis *) +let RESOLVE_TABLE_READ_TAC : tactic = fun (asl,w) -> + let mem_hyps = List.filter_map (fun (_,th) -> + if is_eq(concl th) && + (try let c = string_of_term(concl th) in + let _ = String.index c '2' in + String.length c > 60 && + can (find_term (fun t -> try fst(dest_const t) = "num_of_wordlist" with _ -> false)) (concl th) && + can (find_term (fun t -> try dest_numeral t = Num.num_of_int 2048 with _ -> false)) (concl th) + with _ -> false) + then Some th else None) asl in + if mem_hyps = [] then ALL_TAC (asl,w) else + let reads = find_terms (fun t -> + try let _ = find_term (fun s -> try fst(dest_const s) = "bytes64" with _ -> false) t in + let _ = find_term (fun s -> try fst(dest_const s) = "word_add" with _ -> false) t in + fst(dest_const(fst(strip_comb t))) = "read" && + is_comb t && is_var(rand t) + with _ -> false) w in + let eqs = List.filter_map (fun rd -> + try + let state = rand rd in + (* rd = read (memory :> bytes64(word_add table (word K))) sNN + rator rd = read (memory :> bytes64(word_add table (word K))) + rand(rator rd) = memory :> bytes64(word_add table (word K)) + rand(rand(rator rd)) = bytes64(word_add table (word K)) + rand(rand(rand(rator rd))) = word_add table (word K) *) + let word_add_tm = rand(rand(rand(rator rd))) in + let k_tm = rand(rand word_add_tm) in (* K : num *) + let k = Num.int_of_num(dest_numeral k_tm) in + let mask = k / 8 in + let table_var = rand(rator word_add_tm) in + (* Find memory hypothesis for this state *) + let mem_th = try List.find (fun th -> + try rand(rator(lhs(concl th))) = state with _ -> false) mem_hyps + with Not_found -> List.hd mem_hyps in + let spec = SPECL [table_var; state; mk_numeral(Num.num_of_int mask)] + TABLE_ENTRY_FROM_MEMORY in + let prem_th = CONJ mem_th + (EQT_ELIM(NUM_REDUCE_CONV(mk_comb(mk_comb(`(<)`,mk_numeral(Num.num_of_int mask)), `256`)))) in + let val_eq = MP spec prem_th in + let val_eq' = CONV_RULE(RAND_CONV(REWRITE_CONV[TABLE_NUM_THM] THENC NUM_REDUCE_CONV)) val_eq in + (* Also reduce 8*mask in the LHS to match the goal's concrete address *) + let val_eq'' = CONV_RULE(LAND_CONV(DEPTH_CONV NUM_REDUCE_CONV)) val_eq' in + let n = rhs(concl val_eq'') in + Some(MATCH_MP VAL_WORD_GALOIS_64 + (CONJ val_eq'' (EQT_ELIM(NUM_REDUCE_CONV + (mk_comb(mk_comb(`(<)`,n), `18446744073709551616`)))))) + with _ -> None) reads in + if eqs = [] then ALL_TAC (asl,w) + else REWRITE_TAC eqs (asl,w);; + +(* VPERMD_MEMORY_BRIDGE: connect a sub-read of the 32-byte VMOVDQU write + region to the VPERMD MOD result, closing the memory store goal. *) +let VPERMD_MEMORY_BRIDGE = prove + (`!a (s:x86state) vr n l. + read(memory :> bytes(a:int64, 32)) s = vr /\ + vr MOD 2 EXP (32 * n) = num_of_wordlist(l:int32 list) /\ + n <= 8 + ==> read(memory :> bytes(a, 4 * n)) s = num_of_wordlist l`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + SUBGOAL_THEN + `read(memory :> bytes(a:int64, 4 * n)) s = + read(memory :> bytes(a, 32)) s MOD 2 EXP (8 * (4 * n))` + SUBST1_TAC THENL + [REWRITE_TAC[READ_COMPONENT_COMPOSE; GSYM READ_BYTES_MOD] THEN + GEN_REWRITE_TAC (RAND_CONV o ONCE_DEPTH_CONV) + [GSYM(NUM_REDUCE_CONV `8 * 32`)] THEN + REWRITE_TAC[READ_BYTES_MOD] THEN + SUBGOAL_THEN `MIN 32 (4 * n) = 4 * n` SUBST1_TAC THENL + [REWRITE_TAC[MIN] THEN ASM_ARITH_TAC; + REFL_TAC]; + ASM_REWRITE_TAC[ARITH_RULE `8 * 4 * n = 32 * n`]]);; + +(* VAL_READ_BYTES256: val(read(bytes256 addr) s) = read(bytes(addr,32)) s + Converts a 256-bit word read to a numeric bytes read. *) +let VAL_READ_BYTES256 = prove( + `!addr (s:(int64->byte)). + val(read(bytes256 addr) s :int256) = read(bytes(addr,32)) s`, + REWRITE_TAC[BYTES256_WBYTES; VAL_READ_WBYTES; DIMINDEX_256] THEN + CONV_TAC NUM_REDUCE_CONV);; + +(* ========================================================================= *) +(* Post-helper lemmas. *) +(* ========================================================================= *) + +(* Remaining helper lemmas from the proof file *) + +let DIMINDEX_23 = DIMINDEX_CONV `dimindex(:23)`;; +let DIMINDEX_24 = DIMINDEX_CONV `dimindex(:24)`;; + +let VAL_MOD_23_EQ_AND = prove + (`!w:24 word. (word(val w MOD 2 EXP 23):int32) = + word_and (word_zx w:int32) (word 8388607)`, + GEN_TAC THEN CONV_TAC WORD_BLAST);; + +let REJ_SAMPLE = define + `REJ_SAMPLE l = FILTER (\x:int32. val x < 8380417) + (MAP (\x:24 word. word(val x MOD 2 EXP 23):int32) l)`;; + +let REJ_SAMPLE_EMPTY = prove + (`REJ_SAMPLE [] = []`, + REWRITE_TAC[REJ_SAMPLE; FILTER; MAP]);; + +let REJ_SAMPLE_APPEND = prove + (`!l1 l2. REJ_SAMPLE(APPEND l1 l2) = + APPEND (REJ_SAMPLE l1) (REJ_SAMPLE l2)`, + REWRITE_TAC[REJ_SAMPLE; MAP_APPEND; FILTER_APPEND]);; + +let mldsa_mask_lemma = prove + ((rand o concl o (EXPAND_CASES_CONV THENC NUM_REDUCE_CONV)) + `!i. i < 8 + ==> word_and (word_subword (q:int256) (32*i,32)) (word 8388607):int32 = + word_zx(word_subword (q:int256) (32*i,23):23 word)`, + CONV_TAC WORD_BLAST);; + +let VAL_WORD_ZX_23 = prove + (`!w:23 word. val(word_zx w:int32) < 8388608`, + GEN_TAC THEN REWRITE_TAC[VAL_WORD_ZX_GEN; DIMINDEX_23; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(ISPEC `w:23 word` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_23] THEN CONV_TAC NUM_REDUCE_CONV THEN + DISCH_TAC THEN + SUBGOAL_THEN `val(w:23 word) MOD 4294967296 = val w` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ASM_ARITH_TAC]);; + +let COEFF_FROM_BYTES = prove + ((rand o concl o (EXPAND_CASES_CONV THENC NUM_REDUCE_CONV)) + `!j. j < 8 ==> + word_and (word_zx(word_subword (buf:192 word) (24*j,24):24 word):int32) + (word 8388607) = + word_zx(word_subword (buf:192 word) (24*j,23):23 word)`, + CONV_TAC WORD_BLAST);; + + +(* ========================================================================= *) +(* REJ_SAMPLE algebra. *) +(* ========================================================================= *) + +(* Lemmas that defs.ml / step*.ml need but which weren't in the checkpoint loader. + Extracted verbatim from mldsa_rej_uniform.ml. Load before defs.ml. *) + +(* POPCNT of VMOVMSKPS sign-bit mask = LENGTH(FILTER) — 256-case brute force *) +let POPCNT_EQ_LENGTH_FILTER = prove + (`!x0 x1 x2 x3 x4 x5 x6 x7:int32. + val x0 < 8388608 /\ val x1 < 8388608 /\ val x2 < 8388608 /\ val x3 < 8388608 /\ + val x4 < 8388608 /\ val x5 < 8388608 /\ val x6 < 8388608 /\ val x7 < 8388608 + ==> word_popcount(word( + bitval(bit 31 (word_sub x0 (word 8380417):int32)) + + 2 * bitval(bit 31 (word_sub x1 (word 8380417):int32)) + + 4 * bitval(bit 31 (word_sub x2 (word 8380417):int32)) + + 8 * bitval(bit 31 (word_sub x3 (word 8380417):int32)) + + 16 * bitval(bit 31 (word_sub x4 (word 8380417):int32)) + + 32 * bitval(bit 31 (word_sub x5 (word 8380417):int32)) + + 64 * bitval(bit 31 (word_sub x6 (word 8380417):int32)) + + 128 * bitval(bit 31 (word_sub x7 (word 8380417):int32))):byte) = + LENGTH(FILTER (\x:int32. val x < 8380417) [x0;x1;x2;x3;x4;x5;x6;x7])`, + REPEAT STRIP_TAC THEN + REPEAT(FIRST_X_ASSUM(fun th -> + try let th' = MATCH_MP SIGN_BIT_BITVAL th in REWRITE_TAC[th'] with _ -> ASSUME_TAC th)) THEN + MAP_EVERY ASM_CASES_TAC + [`val(x0:int32) < 8380417`; `val(x1:int32) < 8380417`; + `val(x2:int32) < 8380417`; `val(x3:int32) < 8380417`; + `val(x4:int32) < 8380417`; `val(x5:int32) < 8380417`; + `val(x6:int32) < 8380417`; `val(x7:int32) < 8380417`] THEN + ASM_REWRITE_TAC[FILTER; bitval; LENGTH] THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV));; + +(* LENGTH(FILTER) = sum of bitvals — the 256-case brute force *) +let FILTER_LENGTH_8 = prove + (`!x0 x1 x2 x3 x4 x5 x6 x7:int32. + val x0 < 8388608 /\ val x1 < 8388608 /\ val x2 < 8388608 /\ val x3 < 8388608 /\ + val x4 < 8388608 /\ val x5 < 8388608 /\ val x6 < 8388608 /\ val x7 < 8388608 + ==> LENGTH(FILTER (\x. val x < 8380417) [x0;x1;x2;x3;x4;x5;x6;x7]) = + bitval(val x0 < 8380417) + bitval(val x1 < 8380417) + + bitval(val x2 < 8380417) + bitval(val x3 < 8380417) + + bitval(val x4 < 8380417) + bitval(val x5 < 8380417) + + bitval(val x6 < 8380417) + bitval(val x7 < 8380417)`, + REPEAT STRIP_TAC THEN + MAP_EVERY ASM_CASES_TAC + [`val(x0:int32) < 8380417`; `val(x1:int32) < 8380417`; + `val(x2:int32) < 8380417`; `val(x3:int32) < 8380417`; + `val(x4:int32) < 8380417`; `val(x5:int32) < 8380417`; + `val(x6:int32) < 8380417`; `val(x7:int32) < 8380417`] THEN + ASM_REWRITE_TAC[FILTER; LENGTH; bitval] THEN + CONV_TAC NUM_REDUCE_CONV);; + +(* VMOVMSKPS sign bits + POPCNT = LENGTH(FILTER) for 8 dword lanes *) +let POPCNT_VMOVMSKPS_LEMMA = prove + (`!q:int256. + word_popcount(word( + bitval(bit 31 (word_sub (word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (0,32):int32) (word 8380417):int32)) + + 2 * bitval(bit 31 (word_sub (word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (32,32):int32) (word 8380417):int32)) + + 4 * bitval(bit 31 (word_sub (word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (64,32):int32) (word 8380417):int32)) + + 8 * bitval(bit 31 (word_sub (word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (96,32):int32) (word 8380417):int32)) + + 16 * bitval(bit 31 (word_sub (word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (128,32):int32) (word 8380417):int32)) + + 32 * bitval(bit 31 (word_sub (word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (160,32):int32) (word 8380417):int32)) + + 64 * bitval(bit 31 (word_sub (word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (192,32):int32) (word 8380417):int32)) + + 128 * bitval(bit 31 (word_sub (word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (224,32):int32) (word 8380417):int32))):byte) = + LENGTH(FILTER (\c:int32. val c < 8380417) + [word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (0,32):int32; + word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (32,32); + word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (64,32); + word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (96,32); + word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (128,32); + word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (160,32); + word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (192,32); + word_subword (word_and q (word 226156397384342666605459106258636701594091082888230722833791023177481060351):int256) (224,32)])`, + GEN_TAC THEN REWRITE_TAC[WORD_SUBWORD_AND] THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + REWRITE_TAC[mldsa_mask_lemma] THEN + MATCH_MP_TAC POPCNT_EQ_LENGTH_FILTER THEN + REWRITE_TAC[VAL_WORD_ZX_23]);; + +(* Full iteration bridge: split, length, and bound *) +let SIMD_ITERATION_BRIDGE = prove + (`!inlist:(24 word)list i curlist curlen. + REJ_SAMPLE(SUB_LIST(0,8*i) inlist) = curlist /\ + LENGTH curlist = curlen /\ + 8*(i+1) <= LENGTH inlist + ==> REJ_SAMPLE(SUB_LIST(0,8*(i+1)) inlist) = + APPEND curlist (REJ_SAMPLE(SUB_LIST(8*i,8) inlist)) /\ + LENGTH(REJ_SAMPLE(SUB_LIST(0,8*(i+1)) inlist)) = + curlen + LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) inlist)) /\ + LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) inlist)) <= 8`, + REPEAT STRIP_TAC THENL + [REWRITE_TAC[ARITH_RULE `8*(i+1) = 8*i + 8`] THEN + MP_TAC(ISPECL [`inlist:(24 word)list`; `8*i`; `8`; `0`] SUB_LIST_SPLIT) THEN + REWRITE_TAC[ADD_CLAUSES] THEN DISCH_THEN SUBST1_TAC THEN + ASM_REWRITE_TAC[REJ_SAMPLE_APPEND]; + REWRITE_TAC[ARITH_RULE `8*(i+1) = 8*i + 8`] THEN + MP_TAC(ISPECL [`inlist:(24 word)list`; `8*i`; `8`; `0`] SUB_LIST_SPLIT) THEN + REWRITE_TAC[ADD_CLAUSES] THEN DISCH_THEN SUBST1_TAC THEN + ASM_REWRITE_TAC[REJ_SAMPLE_APPEND; LENGTH_APPEND]; + REWRITE_TAC[REJ_SAMPLE] THEN + W(MP_TAC o PART_MATCH lhand LENGTH_FILTER o lhand o snd) THEN + MATCH_MP_TAC(REWRITE_RULE[IMP_CONJ_ALT] LE_TRANS) THEN + REWRITE_TAC[LENGTH_MAP; LENGTH_SUB_LIST] THEN ARITH_TAC]);; + +(* word_join of 8 consecutive 32-bit subwords reconstructs the original 256-bit word. + Used by the VPERMD bridge to fold the VPERMD expression back to coeffs_ymm3. *) +let WORD_JOIN_SUBWORDS_256 = prove + (`!q:int256. + word_join + (word_join (word_join ((word_subword q (224,32)):int32) ((word_subword q (192,32)):int32):int64) + (word_join ((word_subword q (160,32)):int32) ((word_subword q (128,32)):int32):int64):int128) + (word_join (word_join ((word_subword q (96,32)):int32) ((word_subword q (64,32)):int32):int64) + (word_join ((word_subword q (32,32)):int32) ((word_subword q (0,32)):int32):int64):int128):int256 = q`, + GEN_TAC THEN CONV_TAC WORD_BLAST);; + +(* Standalone VPERMD bridge: given 8 bounds on subwords of q and the table lookup + value of te, the VPERMD expansion of (q, te) mod 2^(32*popcount) equals + num_of_wordlist(FILTER (val + val(word_join + (word_join + (word_join ((word_subword q (32 * val(word_subword te (56,3):3 word), 32)):int32) + ((word_subword q (32 * val(word_subword te (48,3):3 word), 32)):int32):int64) + (word_join ((word_subword q (32 * val(word_subword te (40,3):3 word), 32)):int32) + ((word_subword q (32 * val(word_subword te (32,3):3 word), 32)):int32):int64):int128) + (word_join + (word_join ((word_subword q (32 * val(word_subword te (24,3):3 word), 32)):int32) + ((word_subword q (32 * val(word_subword te (16,3):3 word), 32)):int32):int64) + (word_join ((word_subword q (32 * val(word_subword te (8,3):3 word), 32)):int32) + ((word_subword q (32 * val(word_subword te (0,3):3 word), 32)):int32):int64):int128):int256) MOD + 2 EXP (32 * (bitval(val(word_subword q (0,32):int32) < 8380417) + + bitval(val(word_subword q (32,32):int32) < 8380417) + + bitval(val(word_subword q (64,32):int32) < 8380417) + + bitval(val(word_subword q (96,32):int32) < 8380417) + + bitval(val(word_subword q (128,32):int32) < 8380417) + + bitval(val(word_subword q (160,32):int32) < 8380417) + + bitval(val(word_subword q (192,32):int32) < 8380417) + + bitval(val(word_subword q (224,32):int32) < 8380417))) = + num_of_wordlist(FILTER (\c:int32. val c < 8380417) + [word_subword q (0,32); word_subword q (32,32); + word_subword q (64,32); word_subword q (96,32); + word_subword q (128,32); word_subword q (160,32); + word_subword q (192,32); word_subword q (224,32)])`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + MP_TAC(ISPECL [ + `word_subword (q:int256) (0,32):int32`; + `word_subword (q:int256) (32,32):int32`; + `word_subword (q:int256) (64,32):int32`; + `word_subword (q:int256) (96,32):int32`; + `word_subword (q:int256) (128,32):int32`; + `word_subword (q:int256) (160,32):int32`; + `word_subword (q:int256) (192,32):int32`; + `word_subword (q:int256) (224,32):int32`; + `te:int64` + ] VPERMD_TABLE_CORRECT) THEN + ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[WORD_JOIN_SUBWORDS_256] THEN + CONV_TAC(TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) THEN + DISCH_THEN ACCEPT_TAC);; + +(* ------------------------------------------------------------------------- *) +(* REJ_SAMPLE list decomposition helpers for the post-loop proof. *) +(* ------------------------------------------------------------------------- *) + +(* REJ_SAMPLE of a list is APPEND of REJ_SAMPLE of a prefix and a suffix. *) +let REJ_SAMPLE_SPLIT = prove + (`!(l:(24 word)list) n. + REJ_SAMPLE l = APPEND (REJ_SAMPLE (SUB_LIST (0,n) l)) + (REJ_SAMPLE (SUB_LIST (n, LENGTH l - n) l))`, + REPEAT GEN_TAC THEN REWRITE_TAC[GSYM REJ_SAMPLE_APPEND] THEN + MESON_TAC[SUB_LIST_TOPSPLIT]);; + +(* If a prefix's REJ_SAMPLE has length 256, then the first 256 of REJ_SAMPLE + of the full list equals REJ_SAMPLE of that prefix. Used in the post-loop + JAE-exit case to conclude outlist = SUB_LIST (0,256) (REJ_SAMPLE inlist). *) +let REJ_SAMPLE_PREFIX_256 = prove + (`!(inlist:(24 word)list) k. + LENGTH (REJ_SAMPLE (SUB_LIST (0,k) inlist)) = 256 + ==> SUB_LIST (0,256) (REJ_SAMPLE inlist) = REJ_SAMPLE (SUB_LIST (0,k) inlist)`, + REPEAT STRIP_TAC THEN + MP_TAC(ISPECL [`inlist:(24 word)list`; `k:num`] REJ_SAMPLE_SPLIT) THEN + DISCH_THEN(fun th -> GEN_REWRITE_TAC (LAND_CONV o RAND_CONV) [th]) THEN + MP_TAC(ISPECL + [`REJ_SAMPLE (SUB_LIST (0,k) (inlist:(24 word)list))`; + `REJ_SAMPLE (SUB_LIST (k, LENGTH inlist - k) (inlist:(24 word)list))`; + `256`] SUB_LIST_APPEND_LEFT) THEN + ANTS_TAC THENL [ASM_REWRITE_TAC[LE_REFL]; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN + MATCH_MP_TAC SUB_LIST_REFL THEN + ASM_REWRITE_TAC[LE_REFL]);; + +(* Monotonicity: one more input element adds at most 1 to REJ_SAMPLE length. *) +let REJ_SAMPLE_STEP_LE = prove + (`!(l:(24 word)list) k. + LENGTH (REJ_SAMPLE (SUB_LIST (0, k + 1) l)) <= + LENGTH (REJ_SAMPLE (SUB_LIST (0, k) l)) + 1`, + REPEAT GEN_TAC THEN + ASM_CASES_TAC `k + 1 <= LENGTH (l:(24 word)list)` THENL + [MP_TAC(ISPECL [`l:(24 word)list`; `k:num`; `1:num`; `0:num`] SUB_LIST_SPLIT) THEN + REWRITE_TAC[ADD_CLAUSES] THEN + DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[REJ_SAMPLE_APPEND; LENGTH_APPEND; LE_ADD_LCANCEL] THEN + REWRITE_TAC[REJ_SAMPLE] THEN + W(MP_TAC o PART_MATCH lhand LENGTH_FILTER o lhand o snd) THEN + MATCH_MP_TAC(REWRITE_RULE[IMP_CONJ_ALT] LE_TRANS) THEN + REWRITE_TAC[LENGTH_MAP; LENGTH_SUB_LIST] THEN ARITH_TAC; + SUBGOAL_THEN `SUB_LIST (0, k + 1) (l:(24 word)list) = l /\ + LENGTH (l:(24 word)list) <= k` + (fun th -> SUBST1_TAC(CONJUNCT1 th) THEN + ASM_SIMP_TAC[SUB_LIST_REFL; CONJUNCT2 th] THEN ARITH_TAC) THEN + CONJ_TAC THENL + [MATCH_MP_TAC SUB_LIST_REFL THEN ASM_ARITH_TAC; + ASM_ARITH_TAC]]);; + +(* ========================================================================= *) +(* R9 bridge + JA resolvers. *) +(* ========================================================================= *) + +(* JA branch resolution tactics from the proof file *) +let RESOLVE_JA_ONLY_TAC svar = + fun th -> + TRY(FIRST_X_ASSUM(K ALL_TAC o check (fun th' -> + let c = concl th' in + is_eq c && can (find_term is_cond) c && + can (find_term ((=) svar)) c && + can (find_term ((=) `RIP`)) c))) THEN + ASSUME_TAC th;; + +let RESOLVE_JA_CURLEN_TAC = + FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + can (find_term ((=) `RIP`)) (concl th) && is_eq(concl th) && + can (find_term is_cond) (concl th))) THEN + MATCH_MP_TAC(TAUT `~p ==> (if p then a else b) = b`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[REAL_EQ_SUB_RADD; REAL_OF_NUM_ADD; REAL_OF_NUM_EQ] THEN + UNDISCH_TAC `curlen <= 248` THEN ARITH_TAC;; + +let RESOLVE_JA_OFFSET_TAC = + FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + can (find_term ((=) `RIP`)) (concl th) && is_eq(concl th) && + can (find_term is_cond) (concl th))) THEN + MATCH_MP_TAC(TAUT `~p ==> (if p then a else b) = b`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[REAL_EQ_SUB_RADD; REAL_OF_NUM_ADD; REAL_OF_NUM_EQ] THEN + UNDISCH_TAC `24 * i <= 808` THEN ARITH_TAC;; + +(* ========================================================================= *) +(* PIVOT_VAL_EQ lemma. *) +(* ========================================================================= *) +(* SCALAR_BODY_LEMMA preamble (byte bridges + ACCEPT_REJ_SAMPLE_SINGLETON). *) +(* ========================================================================= *) + +let READ_3BYTES_EL = prove + (`!(inlist:(24 word)list) (buf:int64) mem j n. + LENGTH inlist = n /\ j < n /\ 3 * j + 3 <= 3 * n /\ + read(memory :> bytes(buf, 3 * n)) mem = num_of_wordlist inlist + ==> read(memory :> bytes(word_add buf (word(3 * j)), 3)) mem = + val(EL j inlist)`, + REPEAT STRIP_TAC THEN + MP_TAC(ISPECL [`inlist:(24 word)list`; `j:num`] NUM_OF_WORDLIST_EL) THEN + ASM_REWRITE_TAC[DIMINDEX_24] THEN DISCH_THEN(SUBST1_TAC o SYM) THEN + POP_ASSUM MP_TAC THEN DISCH_THEN(SUBST1_TAC o SYM) THEN + REWRITE_TAC[READ_COMPONENT_COMPOSE] THEN + SUBGOAL_THEN + `read (bytes (buf,3 * n)) (read memory mem) DIV 2 EXP (24 * j) = + read (bytes (word_add buf (word (3*j)), 3 * n - 3*j)) (read memory mem)` + SUBST1_TAC THENL + [REWRITE_TAC[READ_BYTES_DIV; ARITH_RULE `24 * j = 8 * (3 * j)`] THEN + REFL_TAC; + ALL_TAC] THEN + REWRITE_TAC[ARITH_RULE `24 = 8 * 3`; READ_BYTES_MOD] THEN + SUBGOAL_THEN `MIN (3 * n - 3 * j) 3 = 3` SUBST1_TAC THENL + [UNDISCH_TAC `3 * j + 3 <= 3 * n` THEN REWRITE_TAC[MIN] THEN ARITH_TAC; + REFL_TAC]);; + +(* Byte-to-coefficient bridge: 3 bytes of memory, mixed via bytes16 + bytes8 + + word_or (as the AVX2 scalar path does), equal val(EL j inlist). This is + the semantic heart of the filter-rejection reasoning in the scalar body. *) +let BYTE_BRIDGE_3BYTES = prove + (`!(inlist:(24 word)list) (buf:int64) s j n. + LENGTH inlist = n /\ j < n /\ 3 * j + 3 <= 3 * n /\ + read(memory :> bytes(buf, 3 * n)) s = num_of_wordlist inlist + ==> + val(word_or + (word_zx(read(memory :> bytes16 (word_add buf (word (3*j)))) s):(32)word) + (word_shl + (word_zx(read(memory :> bytes8 (word_add buf (word(3*j + 2)))) s):(32)word) + 16):(32)word):num + = val(EL j inlist)`, + REPEAT STRIP_TAC THEN + MP_TAC(SPECL [`inlist:(24 word)list`; `buf:int64`; `s:x86state`; `j:num`; `n:num`] + READ_3BYTES_EL) THEN + ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `(3:num) = 2 + 1` SUBST1_TAC THENL [ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[READ_COMPONENT_COMPOSE; READ_BYTES_COMBINE] THEN + REWRITE_TAC[bytes16; bytes8; READ_COMPONENT_COMPOSE; asword; through; read] THEN + ABBREV_TAC + `a = read (bytes (word_add buf (word ((2 + 1) * j)),2)) (read memory s)` THEN + ABBREV_TAC + `b = read (bytes (word_add buf (word ((2 + 1) * j + 2)),1)) (read memory s)` THEN + SUBGOAL_THEN + `word_add (word_add buf (word((2+1)*j))) (word 2):int64 = + word_add buf (word ((2+1)*j + 2))` SUBST_ALL_TAC THENL + [CONV_TAC WORD_RULE; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(ISPECL [`word_add buf (word ((2 + 1) * j)):int64`; `2`; + `read memory s:int64->(8)word`] READ_BYTES_BOUND) THEN + MP_TAC(ISPECL [`word_add buf (word ((2 + 1) * j + 2)):int64`; `1`; + `read memory s:int64->(8)word`] READ_BYTES_BOUND) THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + REPEAT DISCH_TAC THEN + MP_TAC(ISPECL [`word_zx(word a:(16)word):(32)word`; + `word_shl(word_zx(word b:(8)word):(32)word) 16`] + VAL_WORD_OR_DISJOINT) THEN + ANTS_TAC THENL [CONV_TAC WORD_BLAST; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[VAL_WORD_SHL; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_16; DIMINDEX_8] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `a MOD 65536 = a` SUBST1_TAC THENL + [ASM_SIMP_TAC[MOD_LT]; ALL_TAC] THEN + SUBGOAL_THEN `b MOD 256 = b` SUBST1_TAC THENL + [ASM_SIMP_TAC[MOD_LT]; ALL_TAC] THEN + SUBGOAL_THEN `a MOD 4294967296 = a` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `a < 65536` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `b MOD 4294967296 = b` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `b < 256` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `(65536 * b) MOD 4294967296 = 65536 * b` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `b < 256` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[]);; + +(* Two-state variant: the bytes16 and bytes8 reads can come from different + states as long as both states have the same num_of_wordlist read at buf. *) +let BYTE_BRIDGE_3BYTES_2STATE = prove + (`!(inlist:(24 word)list) (buf:int64) (s1:x86state) (s2:x86state) j n. + LENGTH inlist = n /\ j < n /\ 3 * j + 3 <= 3 * n /\ + read(memory :> bytes(buf, 3 * n)) s1 = num_of_wordlist inlist /\ + read(memory :> bytes(buf, 3 * n)) s2 = num_of_wordlist inlist + ==> + val(word_or + (word_zx(read(memory :> bytes16 (word_add buf (word (3*j)))) s1):(32)word) + (word_shl + (word_zx(read(memory :> bytes8 (word_add buf (word(3*j + 2)))) s2):(32)word) + 16):(32)word):num + = val(EL j inlist)`, + REPEAT STRIP_TAC THEN + SUBGOAL_THEN + `read(memory :> bytes8 (word_add buf (word (3*j + 2)):int64)) s2 = + read(memory :> bytes8 (word_add buf (word (3*j + 2)):int64)) s1` + SUBST1_TAC THENL + [REWRITE_TAC[bytes8; READ_COMPONENT_COMPOSE; asword; through; read] THEN + AP_TERM_TAC THEN REWRITE_TAC[GSYM READ_COMPONENT_COMPOSE] THEN + SUBGOAL_THEN + `!(s:x86state). + read (memory :> bytes (word_add buf (word (3 * j + 2)),1)) s = + (read(memory :> bytes(buf, 3 * n)) s DIV 2 EXP (8 * (3 * j + 2))) MOD + 2 EXP (8 * 1)` + (fun th -> REWRITE_TAC[th]) THENL + [GEN_TAC THEN REWRITE_TAC[READ_COMPONENT_COMPOSE] THEN + REWRITE_TAC[READ_BYTES_DIV; READ_BYTES_MOD] THEN + SUBGOAL_THEN `MIN (3 * n - (3 * j + 2)) 1 = 1` SUBST1_TAC THENL + [UNDISCH_TAC `3 * j + 3 <= 3 * n` THEN REWRITE_TAC[MIN] THEN ARITH_TAC; + REFL_TAC]; + ASM_REWRITE_TAC[]]; + MATCH_MP_TAC BYTE_BRIDGE_3BYTES THEN + EXISTS_TAC `n:num` THEN ASM_REWRITE_TAC[]]);; + +(* Bridge from a bytes32 word-read equation to a bytes(_,4) num-read + equation at the same state. Used in the ACCEPT branch to convert the + MOV store's bytes32 equation at s48 into a bytes(_,4) equation that + VSTEPS can then propagate unchanged through s49 (ADD) and s50 (JMP). *) + +let BYTES32_TO_BYTES = prove + (`!(mem:(x86state,int64->byte)component) (s:x86state) (a:int64) (w:(32)word). + read(mem :> bytes32 a) s = w + ==> read(mem :> bytes(a,4)) s = val w`, + REPEAT GEN_TAC THEN + REWRITE_TAC[bytes32; READ_COMPONENT_COMPOSE; asword; through; read] THEN + ABBREV_TAC + `b = read (bytes (a,4)) + (read (mem:(x86state,int64->byte)component) s)` THEN + DISCH_THEN(MP_TAC o AP_TERM `val:int32->num`) THEN + REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(ISPECL [`a:int64`; `4`; + `(read (mem:(x86state,int64->byte)component) s):int64->(8)word`] + READ_BYTES_BOUND) THEN + ASM_REWRITE_TAC[ARITH_RULE `8 * 4 = 32`] THEN + CONV_TAC NUM_REDUCE_CONV THEN + DISCH_THEN(fun th -> REWRITE_TAC[MATCH_MP MOD_LT th]));; + +(* ACCEPT-branch singleton bridge: given the memory equations and the exact + form of r8val (as it appears after X86_VSTEPS through state s46) with + val r8val < 8380417, derive both: + - the pivot: val r8val = val(EL (8*N+i) inlist) MOD 2^23 + - the filter conclusion: REJ_SAMPLE(SUB_LIST(8*N+i, 1) inlist) = [word(val r8val):int32] + + This packages the pivot + filter into one implication so it can be applied + via MP_TAC without adding the intermediate pivot to the main goal's asl + (which would pollute downstream ASM_REWRITE_TAC rewrites). *) + +let ACCEPT_REJ_SAMPLE_SINGLETON = prove + (`!(inlist:(24 word)list) (buf:int64) (s1:x86state) (s2:x86state) + (r8val:int64) (N:num) (i:num). + LENGTH inlist = 280 /\ + 8 * N + i < 280 /\ + 3 * (8 * N + i) + 3 <= 3 * 280 /\ + read(memory :> bytes(buf, 3 * 280)) s1 = num_of_wordlist inlist /\ + read(memory :> bytes(buf, 3 * 280)) s2 = num_of_wordlist inlist /\ + val(r8val:int64) < 8380417 /\ + r8val = word_zx(word_and + (word_zx (word_zx + (word_or + (word_zx (word_zx (word_zx + (read(memory :> bytes16 (word_add buf (word (3*(8*N+i))))) s1) + :(32)word):(64)word):(32)word) + (word_zx (word_zx + (word_shl + (word_zx (word_zx (word_zx + (read(memory :> bytes8 (word_add buf (word (3*(8*N+i) + 2)))) s2) + :(32)word):(64)word):(32)word) 16) + :(64)word):(32)word) + :(32)word):(64)word):(32)word) + (word 8388607:(32)word):(32)word):int64 + ==> + val r8val = val(EL (8*N+i) inlist) MOD 2 EXP 23 /\ + REJ_SAMPLE(SUB_LIST(8*N + i, 1) inlist) = [word(val r8val):int32]`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + SUBGOAL_THEN `val(r8val:int64) = val(EL (8*N+i) (inlist:(24 word)list)) MOD 2 EXP 23` + ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN + REWRITE_TAC[VAL_WORD_ZX_GEN; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN + `!(x:(32)word). + word_and (word_zx (word_zx x:(64)word):(32)word) + (word 8388607:(32)word) = + word_and x (word 8388607:(32)word)` + (fun th -> REWRITE_TAC[th]) THENL + [CONV_TAC WORD_BLAST; ALL_TAC] THEN + SUBGOAL_THEN + `!(x:(16)word). + word_zx(word_zx(word_zx x:(32)word):(64)word):(32)word = word_zx x` + (fun th -> REWRITE_TAC[th]) THENL + [CONV_TAC WORD_BLAST; ALL_TAC] THEN + SUBGOAL_THEN + `!(x:(8)word). + word_zx(word_zx(word_shl(word_zx(word_zx(word_zx x:(32)word):(64)word):(32)word) 16):(64)word):(32)word = + word_shl(word_zx x:(32)word) 16` + (fun th -> REWRITE_TAC[th]) THENL + [CONV_TAC WORD_BLAST; ALL_TAC] THEN + SUBGOAL_THEN + `!(w:(32)word). word 8388607:(32)word = word(2 EXP 23 - 1)` + (fun th -> REWRITE_TAC[th]) THENL + [GEN_TAC THEN AP_TERM_TAC THEN CONV_TAC NUM_REDUCE_CONV; ALL_TAC] THEN + REWRITE_TAC[VAL_WORD_AND_MASK_WORD] THEN + SUBGOAL_THEN + `!(w:(32)word). val w MOD 2 EXP 23 MOD 18446744073709551616 = val w MOD 2 EXP 23` + (fun th -> REWRITE_TAC[th]) THENL + [GEN_TAC THEN MATCH_MP_TAC MOD_LT THEN + MP_TAC(ARITH_RULE `!x. x MOD 2 EXP 23 < 8388608`) THEN + DISCH_THEN(MP_TAC o SPEC `val(w:(32)word)`) THEN ARITH_TAC; + ALL_TAC] THEN + MP_TAC(SPECL [`inlist:(24 word)list`; `buf:int64`; `s1:x86state`; + `s2:x86state`; `8*N+i:num`; `280`] BYTE_BRIDGE_3BYTES_2STATE) THEN + ASM_REWRITE_TAC[] THEN + DISCH_THEN(fun th -> REWRITE_TAC[th]) THEN CONV_TAC NUM_REDUCE_CONV; + ALL_TAC] THEN + CONJ_TAC THENL [ASM_REWRITE_TAC[]; ALL_TAC] THEN + REWRITE_TAC[SUB_LIST_1] THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[REJ_SAMPLE; MAP; FILTER] THEN + SUBGOAL_THEN `val(word (val (EL (8*N+i) (inlist:(24 word)list)) MOD 2 EXP 23):int32) = + val(r8val:int64)` + SUBST1_TAC THENL + [ASM_REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + SUBGOAL_THEN `val (EL (8*N+i) (inlist:(24 word)list)) MOD 2 EXP 23 MOD 2 EXP 32 = + val (EL (8*N+i) inlist) MOD 2 EXP 23` + SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN + MP_TAC(ARITH_RULE `!x. x MOD 2 EXP 23 < 8388608`) THEN + DISCH_THEN(MP_TAC o SPEC `val(EL (8*N+i) (inlist:(24 word)list))`) THEN + ARITH_TAC; + FIRST_X_ASSUM(fun th -> REWRITE_TAC[SYM th])]; + ASM_REWRITE_TAC[]]);; + +(* ========================================================================= *) + +(* PIVOT_VAL_EQ: key pivot lemma for the REJECT branch of scalar_body_lemma. + + Derived from ACCEPT_REJ_SAMPLE_SINGLETON by dropping the `val r8val < 8380417` + premise and returning only the first conjunct. + + Rationale: the inline derivation of this fact in scalar_body_lemma.ml:816-858 + took 40+ minutes because it rewrites 217 x86-state assumptions via + VAL_WORD_ZX_GEN + NUM_REDUCE_CONV. Proving it as a top-level lemma with + WORD_BLAST normalizers runs in ~1s, then MP_TAC/ANTS_TAC inline is ~0.2s. + + Depends on ACCEPT_REJ_SAMPLE_SINGLETON, BYTE_BRIDGE_3BYTES_2STATE (from + scalar_body_preamble.ml). *) + +let stmt = + let accept_concl = concl ACCEPT_REJ_SAMPLE_SINGLETON in + let vars, body = strip_forall accept_concl in + let prems, cncl = dest_imp body in + let prem_list = conjuncts prems in + (* Remove 'val r8val < 8380417' premise (index 5) *) + let new_prems = list_mk_conj (List.filteri (fun n _ -> n <> 5) prem_list) in + let new_concl = fst(dest_conj cncl) in + list_mk_forall (vars, mk_imp(new_prems, new_concl));; + +let PIVOT_VAL_EQ = prove(stmt, + REPEAT GEN_TAC THEN STRIP_TAC THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[VAL_WORD_ZX_GEN; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN + `!(x:(32)word). + word_and (word_zx (word_zx x:(64)word):(32)word) + (word 8388607:(32)word) = + word_and x (word 8388607:(32)word)` + (fun th -> REWRITE_TAC[th]) THENL + [CONV_TAC WORD_BLAST; ALL_TAC] THEN + SUBGOAL_THEN + `!(x:(16)word). + word_zx(word_zx(word_zx x:(32)word):(64)word):(32)word = word_zx x` + (fun th -> REWRITE_TAC[th]) THENL + [CONV_TAC WORD_BLAST; ALL_TAC] THEN + SUBGOAL_THEN + `!(x:(8)word). + word_zx(word_zx(word_shl(word_zx(word_zx(word_zx x:(32)word):(64)word):(32)word) 16):(64)word):(32)word = + word_shl(word_zx x:(32)word) 16` + (fun th -> REWRITE_TAC[th]) THENL + [CONV_TAC WORD_BLAST; ALL_TAC] THEN + SUBGOAL_THEN + `!(w:(32)word). word 8388607:(32)word = word(2 EXP 23 - 1)` + (fun th -> REWRITE_TAC[th]) THENL + [GEN_TAC THEN AP_TERM_TAC THEN CONV_TAC NUM_REDUCE_CONV; ALL_TAC] THEN + REWRITE_TAC[VAL_WORD_AND_MASK_WORD] THEN + SUBGOAL_THEN + `!(w:(32)word). val w MOD 2 EXP 23 MOD 18446744073709551616 = val w MOD 2 EXP 23` + (fun th -> REWRITE_TAC[th]) THENL + [GEN_TAC THEN MATCH_MP_TAC MOD_LT THEN + MP_TAC(ARITH_RULE `!x. x MOD 2 EXP 23 < 8388608`) THEN + DISCH_THEN(MP_TAC o SPEC `val(w:(32)word)`) THEN ARITH_TAC; + ALL_TAC] THEN + MP_TAC(SPECL [`inlist:(24 word)list`; `buf:int64`; `s1:x86state`; + `s2:x86state`; `8*N+i:num`; `280`] BYTE_BRIDGE_3BYTES_2STATE) THEN + ASM_REWRITE_TAC[] THEN + DISCH_THEN(fun th -> REWRITE_TAC[th]) THEN CONV_TAC NUM_REDUCE_CONV);; + +(* ========================================================================= *) +(* MEMORY_CONJUNCT_CLOSURE lemma. *) +(* ========================================================================= *) + +(* MEMORY_CONJUNCT_CLOSURE — standalone lemma for closing the memory conjunct + in Case B (ACCEPT i+1=K, curlen+1<256). + + After X86_STEPS to s54 + ENSURES_FINAL_STATE_TAC + ASM_REWRITE, the memory + subgoal reduces to: + read (memory :> bytes (res, 4*(curlen+1))) s = num_of_wordlist (APPEND curlist [wa]) + with asms: + - curlen < 256 + - LENGTH curlist = curlen + - read (memory :> bytes (res, 4*curlen)) s = num_of_wordlist curlist + - read (memory :> bytes (word_add res (word (4*curlen)), 4)) s = val wa + + This lemma is specialized to wa:int32 so it matches the list type directly. + Using it inline: MATCH_MP_TAC MEMORY_CONJUNCT_CLOSURE THEN ASM_REWRITE_TAC[] *) + +let MEMORY_CONJUNCT_CLOSURE = prove + (`!(res:int64) (s:x86state) (curlist:int32 list) (curlen:num) (wa:int32). + curlen < 256 /\ + LENGTH curlist = curlen /\ + read (memory :> bytes (res, 4 * curlen)) s = num_of_wordlist curlist /\ + read (memory :> bytes (word_add res (word (4 * curlen)), 4)) s = val wa + ==> read (memory :> bytes (res, 4 * (curlen + 1))) s = + num_of_wordlist (APPEND curlist [wa])`, + REPEAT STRIP_TAC THEN + SUBGOAL_THEN `4 * (curlen + 1) = 4 * curlen + 4` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + W(MP_TAC o PART_MATCH (lhand o rand) BYTES_EQ_NUM_OF_WORDLIST_APPEND o snd) THEN + ASM_REWRITE_TAC[DIMINDEX_32; ARITH_RULE `8 * 4 * l = 32 * l`] THEN + REWRITE_TAC[num_of_wordlist; MULT_CLAUSES; ADD_CLAUSES] THEN + ASM_REWRITE_TAC[ADD_0]);; + +(* ========================================================================= *) +(* Case B closure helpers (VAL_RCX_ADD3). *) +(* ========================================================================= *) + +(* Helper lemmas for Case B. *) + +let VAL_RCX_ADD3 = prove + (`!N i:num. 24 * N + 3 * i <= 837 + ==> val(word_add (word_zx (word (24*N+3*i):int64):(32)word) (word 3:(32)word)) + = 24 * N + 3 * i + 3`, + REPEAT STRIP_TAC THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `!m. m <= 837 ==> m MOD 18446744073709551616 = m /\ + m MOD 4294967296 = m /\ + (m + 3) MOD 4294967296 = m + 3 /\ + (m + 3) MOD 18446744073709551616 = m + 3` + MP_TAC THENL + [GEN_TAC THEN DISCH_TAC THEN + REPEAT CONJ_TAC THEN MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `m <= 837` THEN ARITH_TAC; + DISCH_THEN(MP_TAC o SPEC `24 * N + 3 * i:num`)] THEN + ASM_REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[] THEN ARITH_TAC);; + +let VAL_RCX_ADD3_ZX = prove + (`!N i:num. 24 * N + 3 * i <= 837 + ==> val(word_zx(word_zx(word_add (word_zx (word (24*N+3*i):int64):(32)word) (word 3:(32)word)):(64)word):(32)word) + = 24 * N + 3 * i + 3`, + REPEAT STRIP_TAC THEN + REWRITE_TAC[VAL_WORD_ZX_GEN; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_SIMP_TAC[VAL_RCX_ADD3] THEN + SUBGOAL_THEN `(24 * N + 3 * i + 3) MOD 18446744073709551616 = 24 * N + 3 * i + 3 /\ + (24 * N + 3 * i + 3) MOD 4294967296 = 24 * N + 3 * i + 3` + STRIP_ASSUME_TAC THENL + [CONJ_TAC THEN MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + ASM_REWRITE_TAC[]]);; + +(* ========================================================================= *) +(* SCALAR_BODY_LEMMA (per-iteration specification). *) +(* ========================================================================= *) + +(* scalar_body_lemma.ml — proof of the scalar body subgoal. + Main proof uses MATCH_MP_TAC SCALAR_BODY_LEMMA; the wiring is verified + working at mldsa_rej_uniform.ml:1939. + + Status: structural proof loads in ~15s (down from 1hr) after extracting + PIVOT_VAL_EQ. + + Dependencies (must be loaded BEFORE this file): + - pivot_lemma.ml — PIVOT_VAL_EQ + - memory_conjunct_lemma.ml — MEMORY_CONJUNCT_CLOSURE + - case_b_close.ml — VAL_RCX_ADD3, VAL_RCX_ADD3_ZX + + 3 remaining CHEAT_TACs (all in the ACCEPT i+1=K offset-exit arm): + - count-exit branch: curlen+1=256 case (trivial closure, similar to Case A) + - Case A: offset-exit with curlen+1=256 + - Case B: offset-exit with curlen+1<256 (the interesting case — + interactively validated via case_b_script.ml with 0 CHEATs, + but file-load has subtle seqapply interaction — see reject_progress.md) +*) + +(* Extract 3 bytes of memory at offset 3*j from a 3*n-byte buffer (the natural + byte size for a (24 word)list of length n: 24*n bits = 3*n bytes). *) +let SCALAR_BODY_LEMMA = prove + (`!res buf table (inlist:(24 word)list) pc stackpointer N K i. + LENGTH inlist = 280 /\ + nonoverlapping (word pc, 243) (res, 1024) /\ + nonoverlapping (word pc, 243) (buf, 840) /\ + nonoverlapping (word pc, 243) (table, 2048) /\ + nonoverlapping (res, 1024) (buf, 840) /\ + nonoverlapping (res, 1024) (table, 2048) /\ + 24 * N <= 832 /\ + LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N) inlist)) <= 256 /\ + i < K /\ ~(i = K) /\ 0 < K /\ + (!j. j < K + ==> LENGTH(REJ_SAMPLE(SUB_LIST(0, 8*N + j) inlist)) < 256 /\ + 24 * N + 3 * j <= 837) /\ + (256 <= LENGTH(REJ_SAMPLE(SUB_LIST(0, 8*N + K) inlist)) \/ + 837 < 24 * N + 3 * K) + ==> + ensures x86 + (\s. bytes_loaded s (word pc) (BUTLAST mldsa_rej_uniform_tmc) /\ + read RIP s = word(pc + 181) /\ + read RSP s = stackpointer /\ + read (memory :> bytes (buf,840)) s = num_of_wordlist inlist /\ + read (memory :> bytes (table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list) /\ + read RDI s = res /\ read RSI s = buf /\ read RDX s = table /\ + read YMM0 s = + word 115366376096492355175489748997433888275274855593258845241081954797768348401920 /\ + read YMM1 s = + word 226156397384342666605459106258636701594091082888230722833791023177481060351 /\ + read YMM2 s = + word 225935595421087293402315996791205668696012104344015382954355885915737415681 /\ + (let outlist_i = REJ_SAMPLE(SUB_LIST(0, 8 * N + i) (inlist:(24 word)list)) in + let outlen_i = LENGTH outlist_i in + read RAX s = word outlen_i /\ + read RCX s = word(24 * N + 3 * i) /\ + read(memory :> bytes(res, 4 * outlen_i)) s = num_of_wordlist outlist_i)) + (\s. bytes_loaded s (word pc) (BUTLAST mldsa_rej_uniform_tmc) /\ + read RIP s = word (if i + 1 < K then pc + 181 else pc + 242) /\ + read RSP s = stackpointer /\ + read (memory :> bytes (buf,840)) s = num_of_wordlist inlist /\ + read (memory :> bytes (table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list) /\ + read RDI s = res /\ read RSI s = buf /\ read RDX s = table /\ + read YMM0 s = + word 115366376096492355175489748997433888275274855593258845241081954797768348401920 /\ + read YMM1 s = + word 226156397384342666605459106258636701594091082888230722833791023177481060351 /\ + read YMM2 s = + word 225935595421087293402315996791205668696012104344015382954355885915737415681 /\ + (let outlist_j = REJ_SAMPLE(SUB_LIST(0, 8 * N + (i+1)) (inlist:(24 word)list)) in + let outlen_j = LENGTH outlist_j in + read RAX s = word outlen_j /\ + read RCX s = word(24 * N + 3 * (i+1)) /\ + read(memory :> bytes(res, 4 * outlen_j)) s = num_of_wordlist outlist_j)) + (MAYCHANGE [RIP; RAX; RCX; R8; R9; R10] ,, + MAYCHANGE [ZMM0; ZMM1; ZMM2; ZMM3; ZMM4; + ZMM5; ZMM6; ZMM7; ZMM8; ZMM9; ZMM10; ZMM11; + ZMM12; ZMM13; ZMM14; ZMM15] ,, + MAYCHANGE SOME_FLAGS ,, MAYCHANGE [events] ,, + MAYCHANGE [memory :> bytes(res,1024)])`, + REPEAT GEN_TAC THEN REWRITE_TAC[NONOVERLAPPING_CLAUSES] THEN + (* Split the precondition conjunction: strip all conjuncts EXCEPT the final + disjunction (which we keep as a single assumption for late use). *) + DISCH_THEN(CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC ASSUME_TAC)))))))))))) THEN + FIRST_X_ASSUM(MP_TAC o C MATCH_MP (ASSUME `i:num < K`) o + check (is_forall o concl)) THEN STRIP_TAC THEN + ABBREV_TAC `curlist = REJ_SAMPLE(SUB_LIST(0, 8 * N + i) (inlist:(24 word)list))` THEN + ABBREV_TAC `curlen = LENGTH(curlist:int32 list)` THEN + SUBGOAL_THEN `curlen < 256` ASSUME_TAC THENL + [MAP_EVERY EXPAND_TAC ["curlen"; "curlist"] THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + CONV_TAC(RATOR_CONV(LAND_CONV(TOP_DEPTH_CONV let_CONV))) THEN + ASM_REWRITE_TAC[] THEN + ENSURES_INIT_TAC "s0" THEN + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [36;37] THEN + SUBGOAL_THEN `read RIP s37 = word(pc + 188):int64` + (fun th -> TRY(FIRST_X_ASSUM(K ALL_TAC o check (fun th' -> + let c = concl th' in is_eq c && can (find_term is_cond) c && + can (find_term ((=) `s37:x86state`)) c && + can (find_term ((=) `RIP`)) c))) THEN ASSUME_TAC th) THENL + [FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + can (find_term ((=) `RIP`)) (concl th) && is_eq(concl th) && + can (find_term is_cond) (concl th))) THEN + MATCH_MP_TAC(TAUT `~p ==> (if p then (a:int64) else b) = b`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[REAL_EQ_SUB_RADD; REAL_OF_NUM_ADD; REAL_OF_NUM_EQ] THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; ALL_TAC] THEN + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [38;39] THEN + SUBGOAL_THEN `read RIP s39 = word(pc + 196):int64` + (fun th -> TRY(FIRST_X_ASSUM(K ALL_TAC o check (fun th' -> + let c = concl th' in is_eq c && can (find_term is_cond) c && + can (find_term ((=) `s39:x86state`)) c && + can (find_term ((=) `RIP`)) c))) THEN ASSUME_TAC th) THENL + [FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + can (find_term ((=) `RIP`)) (concl th) && is_eq(concl th) && + can (find_term is_cond) (concl th))) THEN + MATCH_MP_TAC(TAUT `~p ==> (if p then (a:int64) else b) = b`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[REAL_EQ_SUB_RADD; REAL_OF_NUM_ADD; REAL_OF_NUM_EQ] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; ALL_TAC] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC (40--46) THEN + ABBREV_TAC `r8val:int64 = read R8 s46` THEN + ASM_CASES_TAC `val(r8val:int64) < 8380417` THENL + [(* ACCEPT branch: val(r8val) < 8380417, so JAE not taken; store + ADD + JMP *) + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [47] THEN + SUBGOAL_THEN `read RIP s47 = word(pc + 233):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + is_eq(concl th) && can (find_term ((=) `read RIP s47`)) (concl th) && + can (find_term is_cond) (concl th))) THEN + MATCH_MP_TAC(MESON[] `~p ==> (q = (if p then (a:int64) else b) ==> q = b)`) THEN + (fun (asl, g) -> + let t32 = `:(32)word` in + let rec find_wa t = + if is_comb t then + let f, a = dest_comb t in + if is_comb f && is_const (fst(dest_comb f)) && + fst(dest_const(fst(dest_comb f))) = "word_and" && + type_of t = t32 && is_comb a && is_const(fst(dest_comb a)) && + fst(dest_const(fst(dest_comb a))) = "word" && + (try dest_small_numeral(snd(dest_comb a)) = 8388607 with _ -> false) + then Some t + else match find_wa f with Some r -> Some r | None -> find_wa a + else None in + match find_wa g with + | Some t -> + ABBREV_TAC (mk_eq(mk_var("zval", t32), t)) (asl, g) + | None -> failwith "zval abbrev: no match") THEN + REWRITE_TAC[INT_VAL_WORD_SUB_CASES; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `val(zval:(32)word) < 4294967296` ASSUME_TAC THENL + [MP_TAC(ISPEC `zval:(32)word` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_32] THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN + `val(zval:(32)word) MOD 18446744073709551616 MOD 4294967296 = val zval` + SUBST1_TAC THENL + [ASM_SIMP_TAC[MOD_LT; + ARITH_RULE `x < 4294967296 ==> x < 18446744073709551616`]; ALL_TAC] THEN + SUBGOAL_THEN `r8val:int64 = word_zx(zval:(32)word)` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + let c = concl th in + is_eq c && fst(dest_eq c) = `r8val:int64`)) THEN + FIRST_X_ASSUM(SUBST1_TAC o SYM o check (fun th -> + let c = concl th in + is_eq c && snd(dest_eq c) = `zval:(32)word`)) THEN + DISCH_THEN ACCEPT_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(r8val:int64) = val(zval:(32)word)` ASSUME_TAC THENL + [UNDISCH_TAC `r8val:int64 = word_zx(zval:(32)word)` THEN + DISCH_THEN SUBST1_TAC THEN MATCH_MP_TAC VAL_WORD_ZX THEN + REWRITE_TAC[DIMINDEX_32; DIMINDEX_64] THEN ARITH_TAC; ALL_TAC] THEN + COND_CASES_TAC THENL + [UNDISCH_TAC `&8380417:int <= &(val(zval:(32)word))` THEN + UNDISCH_TAC `val(r8val:int64) = val(zval:(32)word)` THEN + UNDISCH_TAC `val(r8val:int64) < 8380417` THEN + REWRITE_TAC[GSYM INT_OF_NUM_LE; GSYM INT_OF_NUM_LT; + GSYM INT_OF_NUM_EQ] THEN INT_ARITH_TAC; + INT_ARITH_TAC]; ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read RIP s47 = (if p then (a:int64) else b)`] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read events s47 = CONS (EventJump (a, b)) c`] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [48] THEN + (* Convert the MOV store's bytes32 equation at s48 into a bytes(_,4) + equation, which VSTEPS can propagate through s49 (ADD) and s50 (JMP). *) + SUBGOAL_THEN + `read(memory :> bytes(word_add res (word(4 * val(word curlen:int64))),4)) + s48 = val(r8val:int64)` + ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o MATCH_MP BYTES32_TO_BYTES o check (fun th -> + let c = concl th in is_eq c && + can (find_term ((=) `bytes32`)) c && + can (find_term ((=) `s48:x86state`)) c)) THEN + FIRST_X_ASSUM(MP_TAC o check (fun th -> + let c = concl th in is_eq c && + can (find_term ((=) `r8val:int64`)) c && + fst(dest_eq c) = `r8val:int64`)) THEN + DISCH_THEN(fun th -> + REWRITE_TAC[th; VAL_WORD_ZX_GEN; DIMINDEX_32; DIMINDEX_64]) THEN + CONV_TAC NUM_REDUCE_CONV THEN + W(fun (_, g) -> + let t32 = `:(32)word` in + let rec find_wa t = + if is_comb t then + let f, a = dest_comb t in + if is_comb f && is_const (fst(dest_comb f)) && + fst(dest_const(fst(dest_comb f))) = "word_and" && + type_of t = t32 && is_comb a && is_const(fst(dest_comb a)) && + fst(dest_const(fst(dest_comb a))) = "word" && + (try dest_small_numeral(snd(dest_comb a)) = 8388607 + with _ -> false) + then Some t + else match find_wa f with Some r -> Some r | None -> find_wa a + else None in + match find_wa g with + | Some t -> + MP_TAC(ISPEC t VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV + | None -> failwith "VAL_BOUND search") THEN + DISCH_TAC THEN + ASM_SIMP_TAC[MOD_LT; + ARITH_RULE `x < 4294967296 ==> x < 18446744073709551616`]; + ALL_TAC] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [49;50] THEN + SUBGOAL_THEN + `REJ_SAMPLE(SUB_LIST(0, 8*N + i + 1) (inlist:(24 word)list)) = + APPEND curlist (REJ_SAMPLE(SUB_LIST(8*N + i, 1) inlist))` + ASSUME_TAC THENL + [SUBGOAL_THEN `8 * N + i + 1 = (8 * N + i) + 1` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + MP_TAC(ISPECL [`inlist:(24 word)list`; `8 * N + i`; `1:num`; `0:num`] + SUB_LIST_SPLIT) THEN + REWRITE_TAC[ADD_CLAUSES] THEN DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[REJ_SAMPLE_APPEND] THEN + ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `8 * N + i < 280` ASSUME_TAC THENL + [UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; ALL_TAC] THEN + (* ACCEPT byte-bridge: apply ACCEPT_REJ_SAMPLE_SINGLETON with the precise + hypotheses, gathered via narrow FIRST_X_ASSUM picks, to avoid the slow + ASM_REWRITE_TAC across the 280+ assumption list. *) + SUBGOAL_THEN + `REJ_SAMPLE(SUB_LIST(8*N + i, 1) (inlist:(24 word)list)) = + [word(val(r8val:int64)):int32]` + ASSUME_TAC THENL + [(* Normalize `1 * val(word(24*N+3*i))` → `3*(8*N+i)` so the r8val shape matches. *) + SUBGOAL_THEN `1 * val(word (24 * N + 3 * i):int64) = 3 * (8 * N + i) /\ + 1 * val(word (24 * N + 3 * i):int64) + 2 = 3 * (8 * N + i) + 2` + STRIP_ASSUME_TAC THENL + [REWRITE_TAC[MULT_CLAUSES; VAL_WORD; DIMINDEX_64] THEN + SUBGOAL_THEN `(24 * N + 3 * i) MOD 2 EXP 64 = 24 * N + 3 * i` + SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `24 * N + 3 * i <= 837` THEN + ARITH_TAC; + ARITH_TAC]; + ALL_TAC] THEN + (* Fetch the 7 hypotheses for ACCEPT_REJ_SAMPLE_SINGLETON and feed them + directly, without ASM_REWRITE. *) + MP_TAC(SPECL [`inlist:(24 word)list`; `buf:int64`; `s39:x86state`; + `s40:x86state`; `r8val:int64`; `N:num`; `i:num`] + ACCEPT_REJ_SAMPLE_SINGLETON) THEN + ANTS_TAC THENL + [CONV_TAC NUM_REDUCE_CONV THEN + REPEAT CONJ_TAC THENL + [(* 1: LENGTH inlist = 280 *) FIRST_ASSUM ACCEPT_TAC; + (* 2: 8*N+i < 280 *) + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + (* 3: 3*(8*N+i)+3 <= 840 *) + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + (* 4: mem read s39 *) FIRST_ASSUM ACCEPT_TAC; + (* 5: mem read s40 *) FIRST_ASSUM ACCEPT_TAC; + (* 6: val r8val < 8380417 *) FIRST_ASSUM ACCEPT_TAC; + (* 7: r8val = word_zx(...(word 3*(8*N+i))...) — discharge via MP_TAC + on the r8val equation from asl (which uses `1*val(word(24*N+3*i))`) + and then ASM_REWRITE_TAC[] using the arith normalization. *) + FIRST_ASSUM(MP_TAC o check (fun th -> + let c = concl th in is_eq c && fst(dest_eq c) = `r8val:int64`)) THEN + ASM_REWRITE_TAC[]]; + DISCH_THEN(fun th -> REWRITE_TAC[CONJUNCT2 th])]; + ALL_TAC] THEN + SUBGOAL_THEN + `REJ_SAMPLE(SUB_LIST(0, 8*N + i + 1) (inlist:(24 word)list)) = + APPEND curlist [word(val(r8val:int64)):int32]` + ASSUME_TAC THENL + [ASM_REWRITE_TAC[]; ALL_TAC] THEN + ASM_CASES_TAC `i + 1 < K` THENL + [ ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ASM_REWRITE_TAC[LENGTH_APPEND; LENGTH] THEN CONV_TAC NUM_REDUCE_CONV THEN + REPEAT CONJ_TAC THENL + [(* RAX: word_zx(word_add(word_zx(word curlen))(word 1)) = word(curlen+1) *) + REWRITE_TAC[VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + (* RCX: word_zx(word_add(word_zx(word(24*N+3*i)))(word 3)) = word(24*N+3*(i+1)) *) + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + (* Memory: bytes(res, 4*(curlen+1)) = num_of_wordlist (APPEND curlist [...]) *) + SUBGOAL_THEN `4 * (curlen + 1) = 4 * curlen + 4` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + (* Fold the RHS's big expanded word back to r8val *) + FIRST_ASSUM(fun th -> let c = concl th in + if is_eq c && fst(dest_eq c) = `r8val:int64` + then GEN_REWRITE_TAC (RAND_CONV o DEPTH_CONV) [SYM th] + else failwith "r8val fold") THEN + MP_TAC(ISPECL [`memory:(x86state,int64->byte)component`; `res:int64`; + `s50:x86state`; `curlist:int32 list`; + `[word(val(r8val:int64)):int32]`; `4 * curlen`; `4`] + BYTES_EQ_NUM_OF_WORDLIST_APPEND) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[DIMINDEX_32] THEN ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN CONJ_TAC THENL + [ASM_REWRITE_TAC[]; + (* Single-element write: num_of_wordlist [word(val r8val):int32] = + val(word(val r8val)) = val r8val (since < 2^32), and the bytes(_,4) + equation is propagated from s48 through VSTEPS 49-50. *) + REWRITE_TAC[num_of_wordlist; MULT_CLAUSES; ADD_CLAUSES] THEN + SUBGOAL_THEN `val(word(val(r8val:int64)):int32) = val r8val` + SUBST1_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `val(r8val:int64) < 8380417` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `val(word curlen:int64) = curlen` + (fun th -> RULE_ASSUM_TAC(REWRITE_RULE[th])) THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_64] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[]]; + (* MAYCHANGE closure *) + REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC]; + (* ACCEPT i+1=K branch: step through CMP EAX,256; JAE (pc+242) to reach + pc+242, using the strengthened lemma's WOP disjunct *) + SUBGOAL_THEN `i + 1 = K` ASSUME_TAC THENL + [UNDISCH_TAC `~(i + 1 < K)` THEN UNDISCH_TAC `i < K` THEN ARITH_TAC; + ALL_TAC] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [51;52] THEN + (* Split on WOP disjunct: count-exit vs offset-exit *) + FIRST_ASSUM(DISJ_CASES_TAC o check (fun th -> is_disj (concl th))) THENL + [(* count-exit: 256 <= LENGTH(REJ_SAMPLE ...), so curlen+1 = 256. + The ACCEPT branch has REJ_SAMPLE(0, 8*N+i+1) = APPEND curlist [r8val], + and with i+1=K we get length = curlen+1, so 256 <= curlen+1. + Combined with curlen < 256: curlen+1 = 256. *) + SUBGOAL_THEN `curlen + 1 = 256` ASSUME_TAC THENL + [UNDISCH_TAC `256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))` THEN + UNDISCH_TAC `i + 1 = K` THEN DISCH_THEN(fun th -> REWRITE_TAC[SYM th]) THEN + UNDISCH_TAC `REJ_SAMPLE (SUB_LIST (0,8 * N + i + 1) (inlist:(24 word)list)) = + APPEND curlist [word(val(r8val:int64)):int32]` THEN + DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[LENGTH_APPEND; LENGTH] THEN + UNDISCH_TAC `LENGTH (curlist:int32 list) = curlen` THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `read RIP s52 = word(pc + 242):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s52`)) c && + can (find_term is_cond) c)) THEN + MATCH_MP_TAC (TAUT `p ==> (if p then (a:int64) else b) = a`) THEN + SUBGOAL_THEN `val(word_add (word_zx (word curlen:int64):(32)word) (word 1:(32)word)) = curlen + 1` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_SIMP_TAC[MOD_LT; ARITH_RULE `curlen < 256 ==> curlen < 18446744073709551616`; + ARITH_RULE `curlen < 256 ==> curlen < 4294967296`; + ARITH_RULE `curlen < 256 ==> curlen + 1 < 4294967296`]; + ALL_TAC] THEN + ASM_REWRITE_TAC[INT_VAL_WORD_SUB_CASES; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC INT_REDUCE_CONV; + ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read RIP s52 = (if p then (a:int64) else b)`] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read events s52 = CONS (EventJump (a, b)) c`] THEN + ENSURES_FINAL_STATE_TAC THEN + REWRITE_TAC[LET_DEF; LET_END_DEF] THEN + REPEAT CONJ_TAC THEN + ASM_REWRITE_TAC[LENGTH_APPEND; LENGTH] THENL + [(* RAX *) + REWRITE_TAC[VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + (* RCX *) + FIRST_X_ASSUM (SUBST1_TAC o SYM o check (fun th -> concl th = `i + 1 = K`)) THEN + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + (* Memory: bytes(res, 4*(curlen+1)) = num_of_wordlist (APPEND curlist [...]) *) + SUBGOAL_THEN `curlen + SUC 0 = curlen + 1` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `4 * (curlen + 1) = 4 * curlen + 4` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + FIRST_ASSUM(fun th -> let c = concl th in + if is_eq c && fst(dest_eq c) = `r8val:int64` + then GEN_REWRITE_TAC (RAND_CONV o DEPTH_CONV) [SYM th] + else failwith "r8val fold") THEN + MP_TAC(ISPECL [`memory:(x86state,int64->byte)component`; `res:int64`; + `s52:x86state`; `curlist:int32 list`; + `[word(val(r8val:int64)):int32]`; `4 * curlen`; `4`] + BYTES_EQ_NUM_OF_WORDLIST_APPEND) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[DIMINDEX_32] THEN ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN CONJ_TAC THENL + [ASM_REWRITE_TAC[]; + REWRITE_TAC[num_of_wordlist; MULT_CLAUSES; ADD_CLAUSES] THEN + SUBGOAL_THEN `val(word(val(r8val:int64)):int32) = val r8val` + SUBST1_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `val(r8val:int64) < 8380417` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `val(word curlen:int64) = curlen` + (fun th -> RULE_ASSUM_TAC(REWRITE_RULE[th])) THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_64] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[]]; + (* MAYCHANGE closure *) + REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC]; + (* offset-exit: 837 < 24*N+3*K, sub-split on 256 <= curlen+1 *) + ASM_CASES_TAC `256 <= curlen + 1` THENL + [(* Case A: curlen+1 = 256 (same output as count-exit). *) + SUBGOAL_THEN `curlen + 1 = 256` ASSUME_TAC THENL + [UNDISCH_TAC `256 <= curlen + 1` THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `read RIP s52 = word(pc + 242):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s52`)) c && + can (find_term is_cond) c)) THEN + MATCH_MP_TAC (TAUT `p ==> (if p then (a:int64) else b) = a`) THEN + SUBGOAL_THEN `val(word_add (word_zx (word curlen:int64):(32)word) (word 1:(32)word)) = curlen + 1` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_SIMP_TAC[MOD_LT; ARITH_RULE `curlen < 256 ==> curlen < 18446744073709551616`; + ARITH_RULE `curlen < 256 ==> curlen < 4294967296`; + ARITH_RULE `curlen < 256 ==> curlen + 1 < 4294967296`]; + ALL_TAC] THEN + ASM_REWRITE_TAC[INT_VAL_WORD_SUB_CASES; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC INT_REDUCE_CONV; + ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read RIP s52 = (if p then (a:int64) else b)`] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read events s52 = CONS (EventJump (a, b)) c`] THEN + ENSURES_FINAL_STATE_TAC THEN + REWRITE_TAC[LET_DEF; LET_END_DEF] THEN + REPEAT CONJ_TAC THEN + ASM_REWRITE_TAC[LENGTH_APPEND; LENGTH] THENL + [(* RAX *) + REWRITE_TAC[VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + (* RCX *) + FIRST_X_ASSUM (SUBST1_TAC o SYM o check (fun th -> concl th = `i + 1 = K`)) THEN + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + (* Memory *) + SUBGOAL_THEN `curlen + SUC 0 = curlen + 1` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `4 * (curlen + 1) = 4 * curlen + 4` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + FIRST_ASSUM(fun th -> let c = concl th in + if is_eq c && fst(dest_eq c) = `r8val:int64` + then GEN_REWRITE_TAC (RAND_CONV o DEPTH_CONV) [SYM th] + else failwith "r8val fold") THEN + MP_TAC(ISPECL [`memory:(x86state,int64->byte)component`; `res:int64`; + `s52:x86state`; `curlist:int32 list`; + `[word(val(r8val:int64)):int32]`; `4 * curlen`; `4`] + BYTES_EQ_NUM_OF_WORDLIST_APPEND) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[DIMINDEX_32] THEN ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN CONJ_TAC THENL + [ASM_REWRITE_TAC[]; + REWRITE_TAC[num_of_wordlist; MULT_CLAUSES; ADD_CLAUSES] THEN + SUBGOAL_THEN `val(word(val(r8val:int64)):int32) = val r8val` + SUBST1_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `val(r8val:int64) < 8380417` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `val(word curlen:int64) = curlen` + (fun th -> RULE_ASSUM_TAC(REWRITE_RULE[th])) THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_64] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[]]; + (* MAYCHANGE *) + REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC]; + (* Case B: curlen+1 < 256 *) + (* Case B: curlen+1 < 256. Step through CMP ECX,837; JA at s52, + then X86_STEPS [53;54] after `wa` abbreviation, then close. *) + SUBGOAL_THEN `read RIP s52 = word(pc + 188):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s52`)) c && + can (find_term is_cond) c)) THEN + MATCH_MP_TAC (TAUT `~p ==> (if p then (a:int64) else b) = b`) THEN + SUBGOAL_THEN `val(word_add (word_zx (word curlen:int64):(32)word) (word 1:(32)word)) = curlen + 1` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_SIMP_TAC[MOD_LT; ARITH_RULE `curlen < 256 ==> curlen < 18446744073709551616`; + ARITH_RULE `curlen < 256 ==> curlen < 4294967296`; + ARITH_RULE `curlen < 256 ==> curlen + 1 < 4294967296`]; + ALL_TAC] THEN + ASM_REWRITE_TAC[INT_VAL_WORD_SUB_CASES; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_SIMP_TAC[MOD_LT; + ARITH_RULE `curlen < 256 ==> curlen + 1 < 18446744073709551616`; + ARITH_RULE `curlen < 256 ==> curlen + 1 < 4294967296`; + ARITH_RULE `256 < 4294967296`] THEN + UNDISCH_TAC `~(256 <= curlen + 1)` THEN + REWRITE_TAC[GSYM INT_OF_NUM_LE] THEN INT_ARITH_TAC; + ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read RIP s52 = (if p then (a:int64) else b)`] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read events s52 = CONS (EventJump (a, b)) c`] THEN + (* Abbreviate word_and sub-expression as `wa` to preserve r8val def *) + (fun (asl,g) -> + let rec findit = function + | [] -> failwith "no r8val def" + | (_, th) :: rest -> + let c = concl th in + if is_eq c && (try fst(dest_var(fst(dest_eq c))) = "r8val" with _ -> false) then + let rhs = snd(dest_eq c) in + (try let _, inner = dest_comb rhs in + ABBREV_TAC (mk_eq(mk_var("wa", type_of inner), inner)) (asl,g) + with _ -> findit rest) + else findit rest + in findit asl) THEN + SUBGOAL_THEN `val(r8val:int64) = val(wa:(32)word)` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + let c = concl th in + is_eq c && (try fst(dest_var(fst(dest_eq c))) = "r8val" with _ -> false))) THEN + DISCH_THEN SUBST1_TAC THEN + MATCH_MP_TAC VAL_WORD_ZX THEN + REWRITE_TAC[DIMINDEX_32; DIMINDEX_64] THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `word(val(r8val:int64)):(32)word = wa` ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN CONV_TAC WORD_BLAST; + ALL_TAC] THEN + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [53;54] THEN + SUBGOAL_THEN `read RIP s54 = word(pc + 242):int64` ASSUME_TAC THENL + [MP_TAC(SPECL [`N:num`; `i:num`] VAL_RCX_ADD3_ZX) THEN + ANTS_TAC THENL [FIRST_ASSUM ACCEPT_TAC; ALL_TAC] THEN + DISCH_TAC THEN + FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s54`)) c && + can (find_term is_cond) c)) THEN + MATCH_MP_TAC (TAUT `p ==> (if p then (a:int64) else b) = a`) THEN + REWRITE_TAC[VAL_WORD_SUB_CASES; VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `837 <= 24 * N + 3 * i + 3` (fun th -> REWRITE_TAC[th]) THENL + [UNDISCH_TAC `837 < 24 * N + 3 * K` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `~((24 * N + 3 * i + 3) - 837 = 0)` + (fun th -> REWRITE_TAC[th]) THENL + [UNDISCH_TAC `837 < 24 * N + 3 * K` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + ALL_TAC] THEN + REWRITE_TAC[DE_MORGAN_THM; NOT_CLAUSES] THEN + MP_TAC(SPECL [`837:num`; `24 * N + 3 * i + 3`] INT_OF_NUM_SUB) THEN + ANTS_TAC THENL + [UNDISCH_TAC `837 < 24 * N + 3 * K` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + ALL_TAC] THEN + DISCH_THEN(fun th -> REWRITE_TAC[SYM th]) THEN INT_ARITH_TAC; + ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS + [`read RIP s54 = (if p then (a:int64) else b)`] THEN + DISCARD_MATCHING_ASSUMPTIONS + [`read events s54 = CONS (EventJump (a, b)) c`] THEN + (* Pre-establish augmented memory invariant via MEMORY_CONJUNCT_CLOSURE *) + SUBGOAL_THEN `val(word curlen:int64) = curlen` + (fun th -> RULE_ASSUM_TAC(REWRITE_RULE[th])) THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_64] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN + `read (memory :> bytes (res, 4 * (curlen + 1))) s54 = + num_of_wordlist (APPEND curlist [word(val(r8val:int64)):int32])` + ASSUME_TAC THENL + [SUBGOAL_THEN `val(word(val(r8val:int64)):int32) = val r8val` + ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `val(r8val:int64) < 8380417` THEN ARITH_TAC; + ALL_TAC] THEN + MATCH_MP_TAC MEMORY_CONJUNCT_CLOSURE THEN + REPEAT CONJ_TAC THENL + [FIRST_ASSUM ACCEPT_TAC; + FIRST_ASSUM ACCEPT_TAC; + FIRST_ASSUM ACCEPT_TAC; + ONCE_REWRITE_TAC[ASSUME `val(word(val(r8val:int64)):int32) = val r8val`] THEN + FIRST_ASSUM ACCEPT_TAC]; + ALL_TAC] THEN + UNDISCH_THEN `r8val:int64 = word_zx(wa:(32)word)` + (fun th -> RULE_ASSUM_TAC(REWRITE_RULE[th]) THEN ASSUME_TAC th) THEN + ENSURES_FINAL_STATE_TAC THEN + REWRITE_TAC[LET_DEF; LET_END_DEF] THEN + REPEAT CONJ_TAC THEN + ASM_REWRITE_TAC[LENGTH_APPEND; LENGTH; + ARITH_RULE `curlen + SUC 0 = curlen + 1`] THENL + [(* RAX *) + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + (* RCX *) + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ZX_GEN; VAL_WORD_ADD; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + (* MAYCHANGE *) + REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC]]]]; + (* REJECT branch: val(r8val) >= 8380417, JAE taken to pc+181 *) + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [47] THEN + SUBGOAL_THEN `read RIP s47 = word(pc + 181):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + is_eq(concl th) && can (find_term ((=) `read RIP s47`)) (concl th) && + can (find_term is_cond) (concl th))) THEN + MATCH_MP_TAC(MESON[] `p ==> (q = (if p then (a:int64) else b) ==> q = a)`) THEN + SUBGOAL_THEN `8380417 <= val(r8val:int64)` ASSUME_TAC THENL + [UNDISCH_TAC `~(val(r8val:int64) < 8380417)` THEN ARITH_TAC; ALL_TAC] THEN + (fun (asl, g) -> + let t32 = `:(32)word` in + let rec find_wa t = + if is_comb t then + let f, a = dest_comb t in + if is_comb f && is_const (fst(dest_comb f)) && + fst(dest_const(fst(dest_comb f))) = "word_and" && + type_of t = t32 && is_comb a && is_const(fst(dest_comb a)) && + fst(dest_const(fst(dest_comb a))) = "word" && + (try dest_small_numeral(snd(dest_comb a)) = 8388607 with _ -> false) + then Some t + else match find_wa f with Some r -> Some r | None -> find_wa a + else None in + match find_wa g with + | Some t -> + ABBREV_TAC (mk_eq(mk_var("zval", t32), t)) (asl, g) + | None -> failwith "zval abbrev: no match") THEN + REWRITE_TAC[INT_VAL_WORD_SUB_CASES; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `val(zval:(32)word) < 4294967296` ASSUME_TAC THENL + [MP_TAC(ISPEC `zval:(32)word` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_32] THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN + `val(zval:(32)word) MOD 18446744073709551616 MOD 4294967296 = val zval` + SUBST1_TAC THENL + [ASM_SIMP_TAC[MOD_LT; + ARITH_RULE `x < 4294967296 ==> x < 18446744073709551616`]; ALL_TAC] THEN + SUBGOAL_THEN `r8val:int64 = word_zx(zval:(32)word)` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + let c = concl th in + is_eq c && fst(dest_eq c) = `r8val:int64`)) THEN + FIRST_X_ASSUM(SUBST1_TAC o SYM o check (fun th -> + let c = concl th in + is_eq c && snd(dest_eq c) = `zval:(32)word`)) THEN + DISCH_THEN ACCEPT_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(r8val:int64) = val(zval:(32)word)` ASSUME_TAC THENL + [UNDISCH_TAC `r8val:int64 = word_zx(zval:(32)word)` THEN + DISCH_THEN SUBST1_TAC THEN MATCH_MP_TAC VAL_WORD_ZX THEN + REWRITE_TAC[DIMINDEX_32; DIMINDEX_64] THEN ARITH_TAC; ALL_TAC] THEN + COND_CASES_TAC THENL + [REFL_TAC; + UNDISCH_TAC `~(&8380417:int <= &(val(zval:(32)word)))` THEN + UNDISCH_TAC `val(r8val:int64) = val(zval:(32)word)` THEN + UNDISCH_TAC `8380417 <= val(r8val:int64)` THEN + REWRITE_TAC[GSYM INT_OF_NUM_LE; GSYM INT_OF_NUM_EQ] THEN + INT_ARITH_TAC]; ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read RIP s47 = (if p then (a:int64) else b)`] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read events s47 = CONS (EventJump (a, b)) c`] THEN + SUBGOAL_THEN + `REJ_SAMPLE(SUB_LIST(0, 8*N + i + 1) (inlist:(24 word)list)) = + APPEND curlist (REJ_SAMPLE(SUB_LIST(8*N + i, 1) inlist))` + ASSUME_TAC THENL + [SUBGOAL_THEN `8 * N + i + 1 = (8 * N + i) + 1` SUBST1_TAC THENL [ARITH_TAC; ALL_TAC] THEN + MP_TAC(ISPECL [`inlist:(24 word)list`; `8 * N + i`; `1:num`; `0:num`] SUB_LIST_SPLIT) THEN + REWRITE_TAC[ADD_CLAUSES] THEN DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[REJ_SAMPLE_APPEND] THEN + ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `8 * N + i < 280` ASSUME_TAC THENL + [UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; ALL_TAC] THEN + (* Pivot lemma: val r8val equals the 23 low bits of the list element. + Use the extracted PIVOT_VAL_EQ top-level lemma for fast application. *) + SUBGOAL_THEN `1 * val(word (24 * N + 3 * i):int64) = 3 * (8 * N + i) /\ + 1 * val(word (24 * N + 3 * i):int64) + 2 = 3 * (8 * N + i) + 2` + STRIP_ASSUME_TAC THENL + [REWRITE_TAC[MULT_CLAUSES; VAL_WORD; DIMINDEX_64] THEN + SUBGOAL_THEN `(24 * N + 3 * i) MOD 2 EXP 64 = 24 * N + 3 * i` + SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `24 * N + 3 * i <= 837` THEN + ARITH_TAC; + ARITH_TAC]; + ALL_TAC] THEN + SUBGOAL_THEN + `val(r8val:int64) = val(EL (8*N+i) (inlist:(24 word)list)) MOD 2 EXP 23` + ASSUME_TAC THENL + [MP_TAC(SPECL [`inlist:(24 word)list`; `buf:int64`; `s39:x86state`; + `s40:x86state`; `r8val:int64`; `N:num`; `i:num`] + PIVOT_VAL_EQ) THEN + ASM_REWRITE_TAC[ARITH_RULE `3 * 280 = 840`] THEN + ANTS_TAC THENL + [UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + DISCH_THEN ACCEPT_TAC]; + ALL_TAC] THEN + SUBGOAL_THEN + `REJ_SAMPLE(SUB_LIST(8 * N + i, 1) (inlist:(24 word)list)) = []` + ASSUME_TAC THENL + [REWRITE_TAC[SUB_LIST_1] THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[REJ_SAMPLE; MAP; FILTER] THEN + REWRITE_TAC[VAL_MOD_23_EQ_AND] THEN + COND_CASES_TAC THENL + [SUBGOAL_THEN + `val (word_and (word_zx (EL (8 * N + i) (inlist:(24 word)list)):int32) + (word 8388607):int32) = + val(EL (8 * N + i) (inlist:(24 word)list)) MOD 2 EXP 23` + SUBST_ALL_TAC THENL + [REWRITE_TAC[GSYM VAL_MOD_23_EQ_AND; VAL_WORD; DIMINDEX_32] THEN + MATCH_MP_TAC MOD_LT THEN + MP_TAC(ISPEC `EL (8 * N + i) (inlist:(24 word)list)` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_24] THEN ARITH_TAC; + ALL_TAC] THEN + UNDISCH_TAC `~(val(r8val:int64) < 8380417)` THEN + ASM_REWRITE_TAC[] THEN ARITH_TAC; + REFL_TAC]; ALL_TAC] THEN + SUBGOAL_THEN + `REJ_SAMPLE(SUB_LIST(0, 8 * N + i + 1) (inlist:(24 word)list)) = curlist` + ASSUME_TAC THENL + [ASM_REWRITE_TAC[APPEND_NIL]; ALL_TAC] THEN + ASM_CASES_TAC `i + 1 < K` THENL + [ ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[APPEND_NIL] THEN ASM_REWRITE_TAC[] THEN + REPEAT CONJ_TAC THENL + [(* RCX: word_zx(word_add(word_zx(word(24*N+3*i)))(word 3)) = word(24*N+3*(i+1)) *) + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + (* MAYCHANGE closure *) + REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC]; + (* i + 1 = K branch of REJECT — fully closed via WOP offset-exit. + Mirrors Case B ACCEPT i+1=K: JA not taken on curlen<256, then + CMP RCX,837 / JA taken to pc+242 using VAL_RCX_ADD3_ZX. *) + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [48] THEN + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [49] THEN + SUBGOAL_THEN `read RIP s49 = word(pc + 188):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s49`)) c && + can (find_term is_cond) c)) THEN + MATCH_MP_TAC (TAUT `~p ==> (if p then (a:int64) else b) = b`) THEN + REWRITE_TAC[INT_VAL_WORD_SUB_CASES; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen MOD 18446744073709551616 MOD 4294967296 = curlen` + SUBST1_TAC THENL + [ASM_SIMP_TAC[MOD_LT; + ARITH_RULE `curlen < 256 ==> curlen < 18446744073709551616`; + ARITH_RULE `curlen < 256 ==> curlen < 4294967296`]; + ALL_TAC] THEN + SUBGOAL_THEN `~(&256:int <= &curlen)` ASSUME_TAC THENL + [UNDISCH_TAC `curlen < 256` THEN + REWRITE_TAC[GSYM INT_OF_NUM_LT; GSYM INT_OF_NUM_LE] THEN INT_ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + UNDISCH_TAC `curlen < 256` THEN + REWRITE_TAC[GSYM INT_OF_NUM_LT] THEN + INT_ARITH_TAC; ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS + [`read RIP s49 = (if p then (a:int64) else b)`] THEN + DISCARD_MATCHING_ASSUMPTIONS + [`read events s49 = CONS (EventJump (a, b)) c`] THEN + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [50] THEN + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [51] THEN + FIRST_ASSUM(DISJ_CASES_TAC o check (fun th -> is_disj (concl th))) THENL + [SUBGOAL_THEN `i + 1 = K` ASSUME_TAC THENL + [UNDISCH_TAC `~(i + 1 < K)` THEN UNDISCH_TAC `i < K` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `F` MP_TAC THENL + [UNDISCH_TAC + `256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))` THEN + UNDISCH_TAC `REJ_SAMPLE (SUB_LIST (0,8 * N + i + 1) inlist) = curlist` THEN + UNDISCH_TAC `i + 1 = K` THEN + DISCH_THEN(SUBST1_TAC o SYM) THEN + DISCH_THEN SUBST1_TAC THEN + UNDISCH_TAC `LENGTH (curlist:int32 list) = curlen` THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + MESON_TAC[]]; + SUBGOAL_THEN `i + 1 = K` ASSUME_TAC THENL + [UNDISCH_TAC `~(i + 1 < K)` THEN UNDISCH_TAC `i < K` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `read RIP s51 = word(pc + 242):int64` ASSUME_TAC THENL + [MP_TAC(SPECL [`N:num`; `i:num`] VAL_RCX_ADD3_ZX) THEN + ANTS_TAC THENL [FIRST_ASSUM ACCEPT_TAC; ALL_TAC] THEN + DISCH_TAC THEN + FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s51`)) c && + can (find_term is_cond) c)) THEN + MATCH_MP_TAC (TAUT `p ==> (if p then (a:int64) else b) = a`) THEN + REWRITE_TAC[VAL_WORD_SUB_CASES; VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `837 <= 24 * N + 3 * i + 3` (fun th -> REWRITE_TAC[th]) THENL + [UNDISCH_TAC `837 < 24 * N + 3 * K` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `~((24 * N + 3 * i + 3) - 837 = 0)` + (fun th -> REWRITE_TAC[th]) THENL + [UNDISCH_TAC `837 < 24 * N + 3 * K` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + ALL_TAC] THEN + REWRITE_TAC[DE_MORGAN_THM; NOT_CLAUSES] THEN + MP_TAC(SPECL [`837:num`; `24 * N + 3 * i + 3`] INT_OF_NUM_SUB) THEN + ANTS_TAC THENL + [UNDISCH_TAC `837 < 24 * N + 3 * K` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + ALL_TAC] THEN + DISCH_THEN(fun th -> REWRITE_TAC[SYM th]) THEN INT_ARITH_TAC; + ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS + [`read RIP s51 = (if p then (a:int64) else b)`] THEN + DISCARD_MATCHING_ASSUMPTIONS + [`read events s51 = CONS (EventJump (a, b)) c`] THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[APPEND_NIL] THEN ASM_REWRITE_TAC[] THEN + REPEAT CONJ_TAC THENL + [ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC]]]]);; + +(* ========================================================================= *) +(* Top-level MLDSA_REJ_UNIFORM_CORRECT proof. *) +(* ========================================================================= *) + +let MLDSA_REJ_UNIFORM_CORRECT = prove + (`!res buf table (inlist:(24 word)list) pc. + LENGTH inlist = 280 /\ + nonoverlapping (word pc, 243) (res, 1024) /\ + nonoverlapping (word pc, 243) (buf, 840) /\ + nonoverlapping (word pc, 243) (table, 2048) /\ + nonoverlapping (res, 1024) (buf, 840) /\ + nonoverlapping (res, 1024) (table, 2048) + ==> ensures x86 + (\s. bytes_loaded s (word pc) (BUTLAST mldsa_rej_uniform_tmc) /\ + read RIP s = word pc /\ + C_ARGUMENTS [res; buf; table] s /\ + read(memory :> bytes(buf,840)) s = num_of_wordlist inlist /\ + read(memory :> bytes(table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list)) + (\s. read RIP s = word(pc + 242) /\ + let outlist = SUB_LIST(0,256) (REJ_SAMPLE inlist) in + let outlen = LENGTH outlist in + C_RETURN s = word outlen /\ + read(memory :> bytes(res,4 * outlen)) s = + num_of_wordlist outlist) + (MAYCHANGE [RIP; RAX; RCX; R8; R9; R10] ,, + MAYCHANGE [ZMM0; ZMM1; ZMM2; ZMM3; ZMM4; + ZMM5; ZMM6; ZMM7; ZMM8; ZMM9; ZMM10; ZMM11; + ZMM12; ZMM13; ZMM14; ZMM15] ,, + MAYCHANGE SOME_FLAGS ,, MAYCHANGE [events] ,, + MAYCHANGE [memory :> bytes(res,1024)])`, + + MAP_EVERY X_GEN_TAC + [`res:int64`; `buf:int64`; `table:int64`; + `inlist:(24 word)list`; `pc:num`] THEN + REWRITE_TAC[C_ARGUMENTS; C_RETURN; SOME_FLAGS; NONOVERLAPPING_CLAUSES] THEN + STRIP_TAC THEN + (* Introduce stackpointer as a ghost variable bound to the initial RSP. + SCALAR_BODY_LEMMA's invariant references `read RSP s = stackpointer`; + the ghost satisfies that without exposing stackpointer at the top level. *) + GHOST_INTRO_TAC `stackpointer:int64` `read RSP` THEN + + (* =================================================================== *) + (* Phase 1: WOP to find loop iteration count N. *) + (* *) + (* Thresholds 808/248 match the CMP instructions: *) + (* CMP eax, 0xF8 (248): JA exits if outlen > 248 *) + (* CMP ecx, 0x328 (808): JA exits if offset > 808 *) + (* At m < N, negation gives: 24*(m+1) <= 832 /\ LENGTH(...) <= 248 *) + (* IMPORTANT: use DISCH_THEN to avoid global NOT_LT rewrite. *) + (* =================================================================== *) + + SUBGOAL_THEN + `?i. 832 < 24 * (i + 1) \/ 248 < LENGTH(REJ_SAMPLE(SUB_LIST(0,8 * i) inlist))` + MP_TAC THENL + [EXISTS_TAC `280:num` THEN ARITH_TAC; + GEN_REWRITE_TAC LAND_CONV [num_WOP]] THEN + DISCH_THEN(X_CHOOSE_THEN `N:num` (CONJUNCTS_THEN2 ASSUME_TAC MP_TAC)) THEN + DISCH_THEN(fun th -> ASSUME_TAC(REWRITE_RULE[DE_MORGAN_THM; NOT_LT] th)) THEN + SUBGOAL_THEN `~(N = 0)` ASSUME_TAC THENL + [DISCH_TAC THEN FIRST_X_ASSUM(MP_TAC o check (is_disj o concl)) THEN + ASM_REWRITE_TAC[MULT_CLAUSES; ADD_CLAUSES; SUB_LIST_CLAUSES] THEN + REWRITE_TAC[REJ_SAMPLE_EMPTY; LENGTH] THEN ARITH_TAC; + ALL_TAC] THEN + + (* =================================================================== *) + (* Phase 2: ENSURES_WHILE_UP2_TAC for the SIMD loop. *) + (* *) + (* Loop head: pc+104 (instruction 18: CMP eax,0xF8) *) + (* Loop exit: pc+181 (instruction 36: scalar section entry) *) + (* UP2 automatically adds bytes_loaded to the invariant. *) + (* Bounds are derived from WOP inside the loop body, not stored. *) + (* =================================================================== *) + + ENSURES_WHILE_UP2_TAC `N:num` `pc + 104` `pc + 181` + `\i s. + read RSP s = stackpointer /\ + read (memory :> bytes (buf,840)) s = num_of_wordlist inlist /\ + read (memory :> bytes (table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list) /\ + read RDI s = res /\ read RSI s = buf /\ read RDX s = table /\ + read YMM0 s = + word 115366376096492355175489748997433888275274855593258845241081954797768348401920 /\ + read YMM1 s = + word 226156397384342666605459106258636701594091082888230722833791023177481060351 /\ + read YMM2 s = + word 225935595421087293402315996791205668696012104344015382954355885915737415681 /\ + let outlist = REJ_SAMPLE(SUB_LIST(0,8*i) inlist) in + let outlen = LENGTH outlist in + read RAX s = word outlen /\ + read RCX s = word(24*i) /\ + read(memory :> bytes(res,4*outlen)) s = num_of_wordlist outlist` THEN + ASM_REWRITE_TAC[LT_REFL] THEN REPEAT CONJ_TAC THENL + + [(* ================================================================= *) + (* Subgoal 1: Pre-loop setup (instructions 1-17, pc -> pc+104) *) + (* FULLY PROVED *) + (* ================================================================= *) + ENSURES_INIT_TAC "s0" THEN + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC (1--17) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[MULT_CLAUSES; ADD_CLAUSES; SUB_LIST_CLAUSES; REJ_SAMPLE_EMPTY; + LENGTH; num_of_wordlist] THEN + CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[READ_COMPONENT_COMPOSE; READ_MEMORY_BYTES_TRIVIAL] THEN + CONV_TAC WORD_REDUCE_CONV; + + (* ================================================================= *) + (* Subgoal 2: Loop body preservation (i -> i+1) *) + (* *) + (* Structure: *) + (* (a) Derive bounds from WOP: curlen <= 248, 24*i <= 808 *) + (* (b) Simulate CMP+JA (instrs 18-19): resolve JA not taken *) + (* (c) Simulate CMP+JA (instrs 20-21): resolve JA not taken *) + (* (d) Simulate SIMD body (instrs 22-35) *) + (* (e) COND_CASES_TAC on i+1 < N *) + (* (f) Semantic proof: relate SIMD to REJ_SAMPLE *) + (* ================================================================= *) + + X_GEN_TAC `i:num` THEN STRIP_TAC THEN ASM_REWRITE_TAC[] THEN + + ABBREV_TAC `curlist = REJ_SAMPLE (SUB_LIST(0,8*i) inlist)` THEN + ABBREV_TAC `curlen = LENGTH(curlist:int32 list)` THEN + CONV_TAC(RATOR_CONV(LAND_CONV(TOP_DEPTH_CONV let_CONV))) THEN + ASM_REWRITE_TAC[] THEN + + (* (a) Get bounds from WOP at i *) + FIRST_ASSUM(MP_TAC o C MATCH_MP (ASSUME `i:num < N`) o + check (fun th -> is_forall(concl th))) THEN + ASM_REWRITE_TAC[] THEN STRIP_TAC THEN + + SUBGOAL_THEN `curlen <= 248` ASSUME_TAC THENL + [ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `24 * i <= 808` ASSUME_TAC THENL + [UNDISCH_TAC `24 * (i + 1) <= 832` THEN ARITH_TAC; ALL_TAC] THEN + + ENSURES_INIT_TAC "s0" THEN + + (* (b) Instructions 18-19: CMP eax,0xF8; JA — not taken *) + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [18;19] THEN + SUBGOAL_THEN `read RIP s19 = word(pc + 111):int64` + (RESOLVE_JA_ONLY_TAC `s19:x86state`) THENL + [RESOLVE_JA_CURLEN_TAC; ALL_TAC] THEN + + (* (c) Instructions 20-21: CMP ecx,0x328; JA — not taken *) + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [20;21] THEN + SUBGOAL_THEN `read RIP s21 = word(pc + 119):int64` + (RESOLVE_JA_ONLY_TAC `s21:x86state`) THENL + [RESOLVE_JA_OFFSET_TAC; ALL_TAC] THEN + + (* (d) SIMD body: all verbose to preserve VMOVDQU→VPSHUFB→VPAND chain *) + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC (22--29) THEN + + (* Abbreviate the 8 masked coefficients from YMM3 after VPAND *) + (* Semantic bridge: use POPCNT_VMOVMSKPS_LEMMA to establish + R9 = word(LENGTH(FILTER)) for the 8 masked dword lanes. + The YMM3 at s26 = word_and(read YMM3 s25)(mask_broadcast). + After ASM_REWRITE, the read R9 s29 expands to the popcount + of the sign-bit mask, and the LEMMA matches directly. *) + SUBGOAL_THEN + `read R9 s29:int64 = + word(LENGTH(FILTER (\c:int32. val c < 8380417) + [word_subword (read YMM3 s26:int256) (0,32):int32; + word_subword (read YMM3 s26) (32,32); + word_subword (read YMM3 s26) (64,32); + word_subword (read YMM3 s26) (96,32); + word_subword (read YMM3 s26) (128,32); + word_subword (read YMM3 s26) (160,32); + word_subword (read YMM3 s26) (192,32); + word_subword (read YMM3 s26) (224,32)]))` + MP_TAC THENL + [ASM_REWRITE_TAC[] THEN + CONV_TAC(LAND_CONV(REWR_CONV word_zx)) THEN + REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + CONV_TAC(ONCE_DEPTH_CONV NUM_REDUCE_CONV) THEN + AP_TERM_TAC THEN + REPEAT(FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + can (find_term ((=) `s25:x86state`)) (concl th) || + can (find_term ((=) `s26:x86state`)) (concl th) || + can (find_term ((=) `s27:x86state`)) (concl th) || + can (find_term ((=) `s28:x86state`)) (concl th) || + can (find_term ((=) `s29:x86state`)) (concl th)))) THEN + SIMP_TAC[WORD_ZX_ZX; DIMINDEX_32; DIMINDEX_64; + ARITH_RULE `32 <= 64`] THEN + SIMP_TAC[WORD_POPCOUNT_WORD_ZX; DIMINDEX_8; DIMINDEX_32; + ARITH_RULE `8 <= 32`] THEN + REWRITE_TAC[VMOVMSKPS_POPCOUNT_EQ; BIT_NESTED_JOIN_8] THEN + REWRITE_TAC[POPCNT_VMOVMSKPS_LEMMA] THEN + MATCH_MP_TAC MOD_LT THEN + TRANS_TAC LTE_TRANS `9` THEN CONJ_TAC THENL + [MATCH_MP_TAC(ARITH_RULE `n <= 8 ==> n < 9`) THEN + W(MP_TAC o PART_MATCH lhand LENGTH_FILTER o lhand o snd) THEN + REWRITE_TAC[LENGTH] THEN ARITH_TAC; + ARITH_TAC]; + DISCARD_MATCHING_ASSUMPTIONS [`read R9 s = x`] THEN STRIP_TAC] THEN + + (* SIMD↔spec: prove BEFORE DISCARD while YMM3/buffer hyps available. + newlen (the FILTER length over SIMD coefficients) = LENGTH(REJ_SAMPLE(...)) + This is the key semantic bridge: VPERMQ+VPSHUFB+VPAND = spec's MAP+FILTER. + The result is state-independent and survives DISCARD_OLDSTATE_TAC. + Approach: WORD_SIMPLE_SUBWORD_CONV reduces the 256-bit VPSHUFB chain + to clean 8-bit byte extractions from the bytes256 memory read. Then + bytes256 → bytes(32) → MOD 2^192 → num_of_wordlist(SUB_LIST). *) + SUBGOAL_THEN + `FILTER (\c:int32. val c < 8380417) + [word_subword (read YMM3 s26:int256) (0,32):int32; + word_subword (read YMM3 s26) (32,32); + word_subword (read YMM3 s26) (64,32); + word_subword (read YMM3 s26) (96,32); + word_subword (read YMM3 s26) (128,32); + word_subword (read YMM3 s26) (160,32); + word_subword (read YMM3 s26) (192,32); + word_subword (read YMM3 s26) (224,32)] = + REJ_SAMPLE(SUB_LIST(8*i,8) inlist)` + ASSUME_TAC THENL + [REWRITE_TAC[REJ_SAMPLE] THEN + REPEAT(FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + can (find_term ((=) `newlen:num`)) (concl th) || + can (find_term ((=) `R9`)) (concl th)))) THEN + REPEAT(FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + let c = concl th in + not(can (find_term ((=) `YMM3`)) c && + can (find_term ((=) (mk_var("s26",`:x86state`)))) c) && + not(can (find_term ((=) `inlist:(24 word)list`)) c && + can (find_term (fun t -> + try fst(dest_const t) = "num_of_wordlist" with _ -> false)) c && + can (find_term ((=) (mk_var("s21",`:x86state`)))) c) && + (can (find_term (fun t -> try let s = fst(dest_var t) in + String.length s > 0 && s.[0] = 's' && s <> "stackpointer" + with _ -> false)) c || + can (find_term ((=) `MAYCHANGE`)) c || + can (find_term ((=) `bytes_loaded`)) c)))) THEN + FIRST_X_ASSUM(fun th -> + if can (find_term ((=) `YMM3`)) (concl th) && + can (find_term ((=) (mk_var("s26",`:x86state`)))) (concl th) && + is_eq(concl th) + then GEN_REWRITE_TAC (ONCE_DEPTH_CONV) [th] + else failwith "") THEN + CONV_TAC(TOP_DEPTH_CONV + (REWR_CONV WORD_SUBWORD_AND ORELSEC WORD_SIMPLE_SUBWORD_CONV)) THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + SUBGOAL_THEN `1 * val(word(24 * i):int64) = 24 * i` SUBST1_TAC THENL + [REWRITE_TAC[MULT_CLAUSES; VAL_WORD; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `24 * i <= 808` THEN ARITH_TAC; + ALL_TAC] THEN + REWRITE_TAC[BYTE_JOIN_ZX; BYTE_JOIN_SUBWORD_ZX; + bytes256; READ_COMPONENT_COMPOSE; asword; through; read] THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + SUBGOAL_THEN + `read(memory :> bytes(word_add buf (word(24*i)),24)) s21 = + num_of_wordlist(SUB_LIST(8*i,8) (inlist:(24 word)list))` + ASSUME_TAC THENL + [REWRITE_TAC[NUM_OF_WORDLIST_SUB_LIST; DIMINDEX_24] THEN + CONV_TAC NUM_REDUCE_CONV THEN + FIRST_ASSUM(fun th -> + if is_eq(concl th) && + can (find_term (fun t -> + try fst(dest_const t) = "num_of_wordlist" with _ -> false)) (concl th) && + not(can (find_term (fun t -> + try fst(dest_const t) = "SUB_LIST" with _ -> false)) (concl th)) && + (let s = string_of_term(concl th) in + let n = String.length s in + let rec has840 j = j + 2 < n && + (s.[j] = '8' && s.[j+1] = '4' && s.[j+2] = '0' || has840 (j+1)) in + has840 0) + then GEN_REWRITE_TAC (RAND_CONV o LAND_CONV o LAND_CONV) [GSYM th] + else failwith "") THEN + REWRITE_TAC[GSYM READ_BYTES_DIV; GSYM READ_BYTES_MOD; + ARITH_RULE `8 * (24 * i) = 192 * i`; + ARITH_RULE `8 * 24 = 192`] THEN + REWRITE_TAC[READ_COMPONENT_COMPOSE; READ_BYTES_DIV; READ_BYTES_MOD] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `MIN (840 - 24 * i) 24 = 24` SUBST1_TAC THENL + [REWRITE_TAC[MIN] THEN + UNDISCH_TAC `24 * i <= 808` THEN ARITH_TAC; + REWRITE_TAC[ARITH_RULE `24 * 8 * i = 8 * (24 * i)`] THEN + GEN_REWRITE_TAC (RAND_CONV o ONCE_DEPTH_CONV) + [GSYM(NUM_REDUCE_CONV `2 EXP (8 * 24)`)] THEN + REWRITE_TAC[READ_BYTES_DIV; READ_BYTES_MOD] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `MIN (840 - 24 * i) 24 = 24` SUBST1_TAC THENL + [REWRITE_TAC[MIN] THEN + UNDISCH_TAC `24 * i <= 808` THEN ARITH_TAC; + REFL_TAC]]; + ALL_TAC] THEN + SUBGOAL_THEN + `read(bytes(word_add buf (word(24*i)),32))(read memory s21) MOD + 2 EXP 192 = + num_of_wordlist(SUB_LIST(8*i,8) (inlist:(24 word)list))` + MP_TAC THENL + [FIRST_X_ASSUM(MP_TAC o REWRITE_RULE[READ_COMPONENT_COMPOSE]) THEN + DISCH_THEN(SUBST1_TAC o SYM) THEN + GEN_REWRITE_TAC (LAND_CONV o ONCE_DEPTH_CONV) + [GSYM(NUM_REDUCE_CONV `8 * 24`)] THEN + REWRITE_TAC[READ_BYTES_MOD] THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[MIN; ARITH_RULE `24 <= 32`]; + ALL_TAC] THEN + ABBREV_TAC + `n32 = read(bytes(word_add buf (word(24*i)),32))(read memory s21)` THEN + DISCH_TAC THEN + ASM_REWRITE_TAC[VAL_MOD_23_EQ_AND; COEFF_FROM_BYTES; + EL_NUM_OF_WORDLIST; NUM_OF_WORDLIST_SUB_LIST; + DIMINDEX_24] THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + ASM_REWRITE_TAC[] THEN + (* COEFF_BYTE_JOIN_WORD converts byte_join coefficients to word(n MOD 2^256 DIV 2^ofs MOD 2^23). + Use GEN_REWRITE with concrete instances for each of the 8 offsets. *) + (* Combined COEFF + MOD_256_192: byte_join coeffs → word(n32 MOD 2^192 DIV 2^k MOD 2^23) *) + GEN_REWRITE_TAC (LAND_CONV o DEPTH_CONV) + (map (fun k -> + let kterm = mk_small_numeral k in + let coeff_th = CONV_RULE NUM_REDUCE_CONV + (SPECL [`n32:num`; kterm] COEFF_BYTE_JOIN_WORD) in + let mod_th = CONV_RULE NUM_REDUCE_CONV + (SPECL [`n32:num`; kterm] MOD_256_192) in + CONV_RULE NUM_REDUCE_CONV (TRANS coeff_th (AP_TERM `word:num->int32` mod_th))) + [0;24;48;72;96;120;144;168]) THEN + CONV_TAC NUM_REDUCE_CONV THEN REWRITE_TAC[DIV_1] THEN + (* Convert huge 2^192 numeral back to 2 EXP 192 so hypothesis can match *) + REWRITE_TAC[GSYM(NUM_REDUCE_CONV `2 EXP 192`)] THEN + ASM_REWRITE_TAC[] THEN + (* Prove LENGTH(SUB_LIST(8*i,8) inlist) = 8 for REJ_SAMPLE_COEFFS *) + SUBGOAL_THEN `LENGTH(SUB_LIST(8*i,8) (inlist:(24 word)list)) = 8` + ASSUME_TAC THENL + [REWRITE_TAC[LENGTH_SUB_LIST] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC(ARITH_RULE + `8 * i + 8 <= l ==> MIN 8 (l - 8 * i) = 8`) THEN + ASM_ARITH_TAC; + ALL_TAC] THEN + ASM_SIMP_TAC[CONV_RULE NUM_REDUCE_CONV MAP_REJ_COEFFS]; + ALL_TAC] THEN + + (* Derive LENGTH from FILTER equality for the ABBREV *) + FIRST_X_ASSUM(fun th -> + if can (find_term (fun t -> try fst(dest_const t) = "FILTER" with _ -> false)) (concl th) && + can (find_term (fun t -> try fst(dest_const t) = "REJ_SAMPLE" with _ -> false)) (concl th) && + is_eq(concl th) && + not(can (find_term ((=) `LENGTH:(int32 list)->num`)) (concl th)) + then ASSUME_TAC th THEN ASSUME_TAC(AP_TERM `LENGTH:(int32 list)->num` th) + else failwith "not filter_eq") THEN + + (* Abbreviate the FILTER length to prevent DISCARD from seeing s26 ref *) + ABBREV_TAC `newlen = LENGTH(FILTER (\c:int32. val c < 8380417) + [word_subword (read YMM3 s26:int256) (0,32):int32; + word_subword (read YMM3 s26) (32,32); + word_subword (read YMM3 s26) (64,32); + word_subword (read YMM3 s26) (96,32); + word_subword (read YMM3 s26) (128,32); + word_subword (read YMM3 s26) (160,32); + word_subword (read YMM3 s26) (192,32); + word_subword (read YMM3 s26) (224,32)])` THEN + + (* The hypothesis `newlen = LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) inlist))` + already exists from the SIMD subgoal proof. It's state-free and + survives DISCARD. No need to re-derive it. *) + + (* Derive FILTER = REJ_SAMPLE BEFORE ABBREVs, while all state hyps exist. + The SIMD subgoal proved LENGTH equality. Now prove FILTER equality + by using read YMM3 s26 = read YMM3 s29 (unchanged through s27-s29). *) + SUBGOAL_THEN + `FILTER (\c:int32. val c < 8380417) + [word_subword (read YMM3 s29:int256) (0,32):int32; + word_subword (read YMM3 s29) (32,32); + word_subword (read YMM3 s29) (64,32); + word_subword (read YMM3 s29) (96,32); + word_subword (read YMM3 s29) (128,32); + word_subword (read YMM3 s29) (160,32); + word_subword (read YMM3 s29) (192,32); + word_subword (read YMM3 s29) (224,32)] = + REJ_SAMPLE(SUB_LIST(8*i,8) (inlist:(24 word)list))` + ASSUME_TAC THENL + [(* Use the FILTER=REJ_SAMPLE hypothesis (s26 version) to reduce to + FILTER P [s29 lanes] = FILTER P [s26 lanes], then ASM_REWRITE closes + because read YMM3 s29 = read YMM3 s26 (same VPAND EXPR). *) + FIRST_X_ASSUM(SUBST1_TAC o SYM o check (fun th -> + can (find_term (fun t -> try fst(dest_const t) = "FILTER" with _ -> false)) (concl th) && + can (find_term (fun t -> try fst(dest_const t) = "REJ_SAMPLE" with _ -> false)) (concl th) && + is_eq(concl th) && + not(can (find_term ((=) `LENGTH:(int32 list)->num`)) (concl th)))) THEN + ASM_REWRITE_TAC[]; + ALL_TAC] THEN + + (* Save the 8 bounds val(word_subword (read YMM3 s29) (k,32)) < 8388608 + BEFORE ABBREV substitutes coeffs_ymm3. Otherwise these bounds are + consumed as intermediate lemmas during CMP_MASK discharge and the + later VPERMD block's Step F picker (which looks for + `word_subword coeffs_ymm3 (k,32) ... < 8388608`) fails with Not_found. *) + SUBGOAL_THEN + `val(word_subword (read YMM3 s29:int256) (0,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (32,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (64,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (96,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (128,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (160,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (192,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (224,32):int32) < 8388608` + STRIP_ASSUME_TAC THENL + [FIRST_ASSUM(fun th -> + let c = concl th in + if is_eq c && + (try fst(dest_const(fst(strip_comb(rhs c)))) = "word_and" with _ -> false) && + (try let ops,args = strip_comb(lhs c) in + fst(dest_const ops) = "read" && + List.length args = 2 && + fst(dest_const(List.hd args)) = "YMM3" + with _ -> false) + then SUBST1_TAC th + else failwith "no YMM3 word_and") THEN + REWRITE_TAC[WORD_SUBWORD_AND] THEN + CONV_TAC(DEPTH_CONV(WORD_SIMPLE_SUBWORD_CONV ORELSEC WORD_NUM_RED_CONV)) THEN + REPEAT CONJ_TAC THEN + MATCH_MP_TAC(ARITH_RULE `n <= 8388607 ==> n < 8388608`) THEN + REWRITE_TAC[VAL_WORD_AND_WORD_LE]; + ALL_TAC] THEN + + (* Ghost state: ABBREV key s29 values before DISCARD kills them. *) + ABBREV_TAC `coeffs_ymm3:int256 = read YMM3 s29` THEN + ABBREV_TAC `cmp_mask:int64 = read R8 s29` THEN + ABBREV_TAC `table_entry:int64 = + read (memory :> bytes64 (word_add table (word (8 * val (cmp_mask:int64))))) s29` THEN + + (* Preserve cmp_mask ↔ coefficient comparison relationship. + cmp_mask encodes which coefficients pass val < Q via VMOVMSKPS. + This connects cmp_mask to the FILTER predicate for the brute force. *) + SUBGOAL_THEN + `val(cmp_mask:int64) = + bitval(val(word_subword (coeffs_ymm3:int256) (0,32):int32) < 8380417) + + 2 * bitval(val(word_subword (coeffs_ymm3:int256) (32,32):int32) < 8380417) + + 4 * bitval(val(word_subword (coeffs_ymm3:int256) (64,32):int32) < 8380417) + + 8 * bitval(val(word_subword (coeffs_ymm3:int256) (96,32):int32) < 8380417) + + 16 * bitval(val(word_subword (coeffs_ymm3:int256) (128,32):int32) < 8380417) + + 32 * bitval(val(word_subword (coeffs_ymm3:int256) (160,32):int32) < 8380417) + + 64 * bitval(val(word_subword (coeffs_ymm3:int256) (192,32):int32) < 8380417) + + 128 * bitval(val(word_subword (coeffs_ymm3:int256) (224,32):int32) < 8380417)` + ASSUME_TAC THENL + [(* Use CMP_MASK_CORRECT: rewrite H31 (cmp_mask ABBREV) using GSYM H30 + (coeffs_ymm3 ABBREV) to replace the VPAND chain with coeffs_ymm3, + then apply the lemma directly. *) + FIRST_ASSUM(fun h30 -> + if is_eq(concl h30) && is_var(lhs(concl h30)) && + (try fst(dest_var(lhs(concl h30))) = "coeffs_ymm3" with _ -> false) && + (try fst(dest_const(fst(strip_comb(rhs(concl h30))))) = "word_and" + with _ -> false) + then + FIRST_ASSUM(fun h31 -> + if is_eq(concl h31) && is_var(lhs(concl h31)) && + (try fst(dest_var(lhs(concl h31))) = "cmp_mask" with _ -> false) && + (try fst(dest_const(fst(strip_comb(rhs(concl h31))))) = "word_zx" + with _ -> false) + then + SUBST1_TAC(REWRITE_RULE[GSYM h30] h31) + else failwith "not h31") + else failwith "not h30") THEN + (* Normalize the bit-31/word_subword/word-of-sum shape to match + CMP_MASK_CORRECT's word_of_bits form: first collapse the + bit 31 (word_subword x (k,32)) patterns via GSYM BIT_SUBWORD_256 + (so we see bit (32k+31) of the nested word_join), then fold the + word(sum of bitvals) via GSYM VMOVMSKPS_BYTE_EQ into word_of_bits. *) + REWRITE_TAC[GSYM BIT_SUBWORD_256; GSYM VMOVMSKPS_BYTE_EQ] THEN + MATCH_MP_TAC(ISPECL [ + `word_subword (coeffs_ymm3:int256) (0,32):int32`; + `word_subword (coeffs_ymm3:int256) (32,32):int32`; + `word_subword (coeffs_ymm3:int256) (64,32):int32`; + `word_subword (coeffs_ymm3:int256) (96,32):int32`; + `word_subword (coeffs_ymm3:int256) (128,32):int32`; + `word_subword (coeffs_ymm3:int256) (160,32):int32`; + `word_subword (coeffs_ymm3:int256) (192,32):int32`; + `word_subword (coeffs_ymm3:int256) (224,32):int32` + ] CMP_MASK_CORRECT) THEN + (* Prove val(word_subword coeffs_ymm3 (k,32)) < 8388608 for each k. + coeffs_ymm3 = word_and(big)(mask) so word_subword distributes, + mask subword = word 8388607, then VAL_WORD_AND_WORD_LE gives bound. *) + FIRST_ASSUM(fun h30 -> + if is_eq(concl h30) && is_var(lhs(concl h30)) && + (try fst(dest_var(lhs(concl h30))) = "coeffs_ymm3" with _ -> false) && + (try fst(dest_const(fst(strip_comb(rhs(concl h30))))) = "word_and" + with _ -> false) + then SUBST1_TAC h30 + else failwith "") THEN + REWRITE_TAC[WORD_SUBWORD_AND] THEN + CONV_TAC(DEPTH_CONV(WORD_SIMPLE_SUBWORD_CONV ORELSEC WORD_NUM_RED_CONV)) THEN + REPEAT CONJ_TAC THEN + MATCH_MP_TAC(ARITH_RULE `n <= 8388607 ==> n < 8388608`) THEN + REWRITE_TAC[VAL_WORD_AND_WORD_LE]; + ALL_TAC] THEN + + DISCARD_OLDSTATE_TAC "s29" THEN CLARIFY_TAC THEN + (* Step 30-32 WITHOUT discard to keep VPERMD hypothesis chain *) + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC (30--32) THEN + SUBGOAL_THEN + `val(read YMM3 s32:int256) MOD 2 EXP (32 * newlen) = + num_of_wordlist(REJ_SAMPLE(SUB_LIST(8*i,8) (inlist:(24 word)list)))` + ASSUME_TAC THENL + [(* VPERMD: apply MLDSA_VPERMD_BRIDGE (proven in defs_extra.ml) + — replaces the former 256-case brute force, eliminating 255 cheats. *) + (* Step A: derive val(table_entry) = (table DIV 2^(64*val cmp_mask)) MOD 2^64 *) + SUBGOAL_THEN + `val(table_entry:int64) = + (num_of_wordlist mldsa_rej_uniform_table DIV + 2 EXP (64 * val(cmp_mask:int64))) MOD 2 EXP 64` + ASSUME_TAC THENL + [SUBGOAL_THEN + `val(read(memory :> bytes64(word_add table (word(8 * val(cmp_mask:int64))))) s32 :int64) = + (num_of_wordlist mldsa_rej_uniform_table DIV 2 EXP (64 * val cmp_mask)) MOD 2 EXP 64` + MP_TAC THENL + [MATCH_MP_TAC TABLE_ENTRY_FROM_MEMORY THEN CONJ_TAC THENL + [ASM_REWRITE_TAC[]; + FIRST_X_ASSUM(fun th -> + if can (find_term ((=) `bitval`)) (concl th) && is_eq(concl th) && + (try fst(dest_var(rand(lhs(concl th)))) = "cmp_mask" with _ -> false) + then SUBST1_TAC th else failwith "") THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (0,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (32,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (64,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (96,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (128,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (160,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (192,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (224,32):int32) < 8380417` BITVAL_BOUND) THEN + ARITH_TAC]; + ASM_REWRITE_TAC[]]; ALL_TAC] THEN + (* Step B: substitute read YMM3 s32 into goal (exposes the VPERMD expansion). *) + FIRST_X_ASSUM (fun th -> + let s = string_of_term (concl th) in + let n = String.length s in + let contains needle = + let k = String.length needle in + let rec scan i = i + k <= n && (String.sub s i k = needle || scan (i+1)) in + scan 0 in + if contains "read YMM3 s32" && is_eq(concl th) && not(contains "MAYCHANGE") + then GEN_REWRITE_TAC (LAND_CONV o ONCE_DEPTH_CONV) [th] THEN ASSUME_TAC th + else failwith "not ymm3 s32") THEN + (* Step C: rewrite (32 * newlen) → (32 * bitval_sum) via newlen = LENGTH(FILTER) + and FILTER=REJ_SAMPLE; also convert RHS REJ_SAMPLE → FILTER. *) + (fun (asl, w) -> + let contains_s needle s = + let n = String.length needle in let m = String.length s in + let rec scan i = i + n <= m && (String.sub s i n = needle || scan (i+1)) in + scan 0 in + let b k = + let needle = Printf.sprintf "word_subword coeffs_ymm3 (%d,32)" k in + snd(List.find (fun (_,th) -> + let s = string_of_term (concl th) in + contains_s needle s && contains_s "< 8388608" s && not(contains_s "==>" s)) asl) in + let bounds = CONJ (b 0) (CONJ (b 32) (CONJ (b 64) (CONJ (b 96) (CONJ (b 128) (CONJ (b 160) (CONJ (b 192) (b 224))))))) in + let flt_spec = ISPECL [ + `word_subword (coeffs_ymm3:int256) (0,32):int32`; + `word_subword (coeffs_ymm3:int256) (32,32):int32`; + `word_subword (coeffs_ymm3:int256) (64,32):int32`; + `word_subword (coeffs_ymm3:int256) (96,32):int32`; + `word_subword (coeffs_ymm3:int256) (128,32):int32`; + `word_subword (coeffs_ymm3:int256) (160,32):int32`; + `word_subword (coeffs_ymm3:int256) (192,32):int32`; + `word_subword (coeffs_ymm3:int256) (224,32):int32` + ] FILTER_LENGTH_8 in + let length_filter_eq = MP flt_spec bounds in + let newlen_def = snd(List.find (fun (_, th) -> + is_eq(concl th) && is_var(lhs(concl th)) && + (try fst(dest_var(lhs(concl th))) = "newlen" with _ -> false)) asl) in + let filter_rej_eq = snd(List.find (fun (_, th) -> + let s = string_of_term (concl th) in + contains_s "FILTER" s && contains_s "REJ_SAMPLE" s && is_eq(concl th) && + not(contains_s "LENGTH" s) && not(contains_s "==>" s)) asl) in + let newlen_bitval = TRANS (TRANS newlen_def + (AP_TERM `LENGTH:int32 list -> num` (SYM filter_rej_eq))) length_filter_eq in + REWRITE_TAC[newlen_bitval; SYM filter_rej_eq] (asl, w)) THEN + (* Step D: fold raw memory read back to table_entry, then collapse word_zx(word_zx ...) via + WORD_SIMPLE_SUBWORD_CONV to expose word_subword table_entry (k,3). *) + (fun (asl, w) -> + let contains_s needle s = + let n = String.length needle in let m = String.length s in + let rec scan i = i + n <= m && (String.sub s i n = needle || scan (i+1)) in + scan 0 in + let cm_sym = + let th = snd(List.find (fun (_, th) -> + is_eq(concl th) && + (try fst(dest_var(rand(lhs(concl th)))) = "cmp_mask" with _ -> false) && + contains_s "bitval" (string_of_term (concl th))) asl) in + SYM th in + let te_eqs = List.filter_map (fun (_, th) -> + let s = string_of_term (concl th) in + if is_eq(concl th) && contains_s "table_entry" s && contains_s "bytes64" s + then Some th else None) asl in + (REWRITE_TAC[cm_sym] THEN REWRITE_TAC te_eqs THEN + CONV_TAC(TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV)) (asl, w)) THEN + (* Step E: apply MLDSA_VPERMD_BRIDGE specialized to coeffs_ymm3 and table_entry. *) + MATCH_MP_TAC (ISPECL [`coeffs_ymm3:int256`; `table_entry:int64`] MLDSA_VPERMD_BRIDGE) THEN + (* Step F: discharge the antecedent using targeted rewrites (bounds + te_val + cm_sym). + Full ASM_REWRITE_TAC hangs on the large assumption list, but this focused + set closes the 9 antecedent conjuncts in ~2s. *) + W(fun (asl,_) -> + let contains_s needle s = + let n = String.length needle in let m = String.length s in + let rec scan i = i + n <= m && (String.sub s i n = needle || scan (i+1)) in + scan 0 in + let b k = + let needle = Printf.sprintf "word_subword coeffs_ymm3 (%d,32)" k in + snd(List.find (fun (_,th) -> + let s = string_of_term (concl th) in + contains_s needle s && contains_s "< 8388608" s && not(contains_s "==>" s)) asl) in + let cm_sym = + let th = snd(List.find (fun (_, th) -> + is_eq(concl th) && + (try fst(dest_var(rand(lhs(concl th)))) = "cmp_mask" with _ -> false) && + contains_s "bitval" (string_of_term (concl th))) asl) in + SYM th in + let te_val = snd(List.find (fun (_, th) -> + is_eq(concl th) && + (try let l = lhs(concl th) in + fst(dest_comb l) = `val:int64->num` && + fst(dest_var(rand l)) = "table_entry" + with _ -> false) && + contains_s "DIV" (string_of_term (concl th))) asl) in + REWRITE_TAC [b 0; b 32; b 64; b 96; b 128; b 160; b 192; b 224; te_val; cm_sym]); + ALL_TAC] THEN + (* VSTEPS for all 3 steps to preserve bytes256 + VPERMD hyps *) + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC (33--35) THEN + + (* (e) COND_CASES_TAC: continue (i+1 < N) vs exit (~(i+1 < N)) *) + COND_CASES_TAC THENL + [(* i+1 < N: continue looping *) + (* Derive new region memory content BEFORE ENSURES consumes state. + From the bytes256 write hypothesis (VMOVDQU step), derive: + read(memory :> bytes(addr, 32)) sN = val(bytes256 RHS) + with address normalized (val(word curlen) → curlen). + This is used by VPERMD_MEMORY_BRIDGE in the memory store goal. *) + (fun (asl,w) -> + try + (* Find bytes256 hyp with s35 and res in address *) + let b256_th = snd(find (fun (_,th) -> + is_eq(concl th) && + can (find_term (fun t -> try fst(dest_const t) = "bytes256" with _ -> false)) (concl th) && + can (find_term (fun t -> try fst(dest_var t) = "res" with _ -> false)) (concl th) && + can (find_term (fun t -> try fst(dest_var t) = "s35" with _ -> false)) (concl th) && + not(can (find_term (fun t -> try fst(dest_const t) = "MAYCHANGE" with _ -> false)) (concl th))) asl) in + (* Find read YMM3 s35 = to get clean RHS *) + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + let ymm3_s35 = snd(find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_var "s35")) (lhs(concl th)) && + can (find_term (has_const "YMM3")) (lhs(concl th)) && + not(can (find_term (has_const "MOD")) (concl th)) && + not(can (find_term (has_const "bytes256")) (concl th))) asl) in + (* Chain: bytes256 s35 = = YMM3 s35 by transitivity *) + let b256_ymm3 = TRANS b256_th (SYM ymm3_s35) in + (* Normalize address: val(word curlen) → curlen *) + let curlen_bound = snd(find (fun (_,th) -> + try concl th = `curlen <= 248` with _ -> false) asl) in + let mk_norm dim_thm dim_val = + let vwe = CONV_RULE NUM_REDUCE_CONV (REWRITE_RULE[dim_thm] (INST_TYPE [dim_val,`:N`] VAL_WORD_EQ)) in + let lt = CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 4294967296`) curlen_bound) in + try MATCH_MP vwe lt with _ -> + let lt64 = CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 18446744073709551616`) curlen_bound) in + MATCH_MP vwe lt64 in + let curlen_norm_32 = mk_norm DIMINDEX_32 `:32` in + let curlen_norm_64 = mk_norm DIMINDEX_64 `:64` in + let b256_norm = REWRITE_RULE[curlen_norm_32; curlen_norm_64] b256_ymm3 in + (* Convert: val(bytes256 addr s35) = val(read YMM3 s35) + address normalize *) + let val_eq = AP_TERM `val:int256->num` b256_norm in + let bytes32_eq = CONV_RULE(LAND_CONV( + REWRITE_CONV[READ_COMPONENT_COMPOSE; VAL_READ_BYTES256] THENC + REWRITE_CONV[GSYM READ_COMPONENT_COMPOSE])) val_eq in + (* Result: read(memory :> bytes(addr,32)) s35 = val(read YMM3 s35) *) + ASSUME_TAC bytes32_eq (asl,w) + with e -> + ALL_TAC (asl,w)) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + (* Establish iteration bounds *) + SUBGOAL_THEN `8 * (i + 1) <= LENGTH(inlist:(24 word)list)` ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN + UNDISCH_TAC `24 * (i + 1) <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + (* Use the SIMD↔spec result: newlen = LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8))) + ABBREV_TAC replaced FILTER... with newlen in this hypothesis *) + FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + is_eq(concl th) && + can (find_term ((=) `newlen:num`)) (concl th) && + can (find_term (fun t -> + try fst(dest_const t) = "REJ_SAMPLE" with _ -> false)) (concl th))) THEN + (* Apply SIMD_ITERATION_BRIDGE to split REJ_SAMPLE across iterations *) + MP_TAC(ISPECL [`inlist:(24 word)list`; `i:num`; `curlist:int32 list`; `curlen:num`] + SIMD_ITERATION_BRIDGE) THEN + ASM_REWRITE_TAC[] THEN STRIP_TAC THEN + ASM_REWRITE_TAC[LENGTH_APPEND; NUM_OF_WORDLIST_APPEND] THEN + (* Clean state hypotheses before word arithmetic — keep MAYCHANGE and memory *) + DISCARD_ASSUMPTIONS_TAC (fun th -> + let c = concl th in + (can (term_match [] `read x s = (y:num)`) c && + not (can (find_term (fun t -> try fst(dest_const t) = "memory" with _ -> false)) c)) || + can (term_match [] `bytes_loaded x y z`) c || + can (term_match [] `read x s <=> y`) c) THEN + ABBREV_TAC `lenrej = LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) inlist))` THEN + REPEAT CONJ_TAC THENL + [(* RAX: word(curlen + lenrej) — word arithmetic. + Use targeted UNDISCH (not ASM_ARITH_TAC — hangs on ~200 asl). *) + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lenrej <= 8` THEN + SPEC_TAC(`lenrej:num`, `l:num`) THEN GEN_TAC THEN + SPEC_TAC(`curlen:num`, `c:num`) THEN GEN_TAC THEN + REPEAT DISCH_TAC THEN + SUBGOAL_THEN `c < 4294967296 /\ c < 18446744073709551616 /\ + l < 4294967296 /\ l < 18446744073709551616 /\ + c + l < 4294967296 /\ c + l < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `c <= 248` THEN UNDISCH_TAC `l <= 8` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC; + (* RCX: word(24*(i+1)) — word arithmetic *) + REWRITE_TAC[ARITH_RULE `24 * (i + 1) = 24 * i + 24`] THEN + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + UNDISCH_TAC `24 * i <= 808` THEN + SPEC_TAC(`24 * i`, `n:num`) THEN GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `n < 4294967296 /\ n < 18446744073709551616 /\ + n + 24 < 4294967296 /\ n + 24 < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `n <= 808` THEN ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC; + (* Memory store: use VPERMD_MEMORY_BRIDGE + We have (from PRE-ENSURES): + read(memory :> bytes(addr, 32)) s35 = val(read YMM3 sN) + And (from VPERMD): + val(read YMM3 sN) MOD 2^(32*lenrej) = num_of_wordlist(REJ_SAMPLE(...)) + VPERMD_MEMORY_BRIDGE chains these to close the sub-read goal. *) + REWRITE_TAC[DIMINDEX_32; + ARITH_RULE `4 * (curlen + lenrej) = 4 * curlen + 4 * lenrej`; + ARITH_RULE `32 * curlen = 8 * (4 * curlen)`] THEN + REWRITE_TAC[MEMORY_BYTES_SPLIT] THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[EQ_ADD_LCANCEL; EQ_MULT_LCANCEL; EXP_EQ_0; ARITH_EQ] THEN + (* Goal: read(bytes(addr, 4*lenrej)) s35 = num_of_wordlist(REJ_SAMPLE(...)) + Use VPERMD_MEMORY_BRIDGE with the PRE-ENSURES bytes32 hypothesis. + First find the hypothesis, then build the full closing tactic. *) + (fun (asl,w) -> + (* Find bytes32 hyp, VPERMD MOD hyp, lenrej bound, then forward-chain *) + try + (* 1. bytes32 hypothesis: read(bytes(addr,32)) s35 = vr *) + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + let bytes32_hyp = snd(List.find (fun (_,th) -> + is_eq(concl th) && + can (find_term (fun t -> try dest_numeral t = Num.num_of_int 32 with _ -> false)) (lhs(concl th)) && + can (find_term (fun t -> try fst(dest_var t) = "s35" with _ -> false)) (lhs(concl th)) && + can (find_term (fun t -> try fst(dest_var t) = "res" with _ -> false)) (lhs(concl th)) && + not(can (find_term (fun t -> try fst(dest_const t) = "bytes256" with _ -> false)) (lhs(concl th)))) asl) in + (* Find newlen = lenrej hypothesis *) + let newlen_eq = snd(List.find (fun (_,th) -> + try is_eq(concl th) && has_var "newlen" (lhs(concl th)) && + has_var "lenrej" (rhs(concl th)) + with _ -> false) asl) in + (* Find VPERMD MOD hyp: val(YMM3 sN) MOD 2^(32*newlen) = num_of_wordlist(...) + May be for s34 or s33 — find the most recent one *) + let vpermd_hyp_raw = snd(List.find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "MOD")) (concl th) && + can (find_term (has_var "newlen")) (concl th) && + can (find_term (has_const "num_of_wordlist")) (concl th)) asl) in + (* Normalize: replace newlen with lenrej *) + let vpermd_hyp_1 = REWRITE_RULE[newlen_eq] vpermd_hyp_raw in + (* The VPERMD hyp may use a different state (s34) than bytes32 (s35). + Bridge: find read YMM3 s35 = E and read YMM3 s34 = E, chain them. *) + let vpermd_hyp = try + (* Find read YMM3 sN = — require int256 RHS and YMM3 on LHS *) + let is_ymm3_read th = + is_eq(concl th) && + type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) in + let ymm35 = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var "s35")) (lhs(concl th))) asl) in + let ymm34 = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var "s34")) (lhs(concl th))) asl) in + (* read YMM3 s35 = E, read YMM3 s34 = E => read YMM3 s35 = read YMM3 s34 *) + let ymm_eq = TRANS ymm35 (SYM ymm34) in + let val_eq = AP_TERM `val:int256->num` ymm_eq in + (* Rewrite s34 → s35 in the VPERMD hypothesis *) + REWRITE_RULE[GSYM val_eq] vpermd_hyp_1 + with _ -> + vpermd_hyp_1 in + (* 3. lenrej <= 8: directly available *) + let lenrej_bound = snd(List.find (fun (_,th) -> + try is_binary "<=" (concl th) && + has_var "lenrej" (lhand(concl th)) && + dest_small_numeral(rand(concl th)) = 8 + with _ -> false) asl) in + (* Forward chain: MATCH_MP VPERMD_MEMORY_BRIDGE (bytes32 /\ mod /\ bound) *) + let bridge = MATCH_MP VPERMD_MEMORY_BRIDGE + (CONJ bytes32_hyp (CONJ vpermd_hyp lenrej_bound)) in + REWRITE_TAC[bridge] (asl,w) + with e -> + failwith "memstore bridge derivation failed")]; + + (* ~(i+1 < N): exit to pc+181. + Approach: instead of DISJ_CASES on the outer disjunct, first derive + N = i+1, then ASM_CASES on `248 < curlen + lenrej`: + * if true: count-exit fires (JAE at s37 to pc+181) — same as old J2 + * if false: offset-exit path — VSTEPS 38-39 do CMP ecx/JA exit + This avoids the artificial J1/J2 split that required a separate + offset-exit proof. *) + SUBGOAL_THEN `N = i + 1` ASSUME_TAC THENL + [UNDISCH_TAC `~(i + 1 < N)` THEN UNDISCH_TAC `i:num < N` THEN ARITH_TAC; + ALL_TAC] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC (36--37) THEN + FIRST_X_ASSUM(DISJ_CASES_TAC o check (is_disj o concl)) THENL + [(* J1: offset exit (832 < 24 * (N + 1) disjunct holds). + Split on whether count-exit ALSO fires: + * J1-A: 248 < curlen + lr → JAE at s37 fires directly, same as J2. + * J1-B: curlen + lr <= 248 → JAE falls through, CMP ecx/JA at s38-39 + fires because 808 < 24*(i+1) (from disjunct + N=i+1). *) + ASM_CASES_TAC + `248 < curlen + LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) (inlist:(24 word)list)))` + THENL + [(* J1-A: JAE at s37 fires. Derive J2's precondition, run J2 body. *) + SUBGOAL_THEN + `248 < LENGTH(REJ_SAMPLE(SUB_LIST(0,8 * N) (inlist:(24 word)list)))` + ASSUME_TAC THENL + [MP_TAC(ISPECL [`inlist:(24 word)list`; `i:num`; `curlist:int32 list`; + `curlen:num`] SIMD_ITERATION_BRIDGE) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[] THEN + UNDISCH_TAC `24 * (i + 1) <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + STRIP_TAC THEN + ASM_REWRITE_TAC[ARITH_RULE `8 * (i + 1) = 8 * i + 8`; + SUB_LIST_SPLIT; REJ_SAMPLE_APPEND; LENGTH_APPEND; + ADD_CLAUSES]; + ALL_TAC] THEN + (* J1-A body is identical to J2 body; run through it. + Must keep this in sync if J2 body changes. *) + SUBGOAL_THEN `newlen <= 8` ASSUME_TAC THENL + [MP_TAC(ISPECL [`inlist:(24 word)list`; `i:num`; `curlist:int32 list`; + `curlen:num`] SIMD_ITERATION_BRIDGE) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[] THEN + UNDISCH_TAC `24 * (i + 1) <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + STRIP_TAC THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + (fun (asl,w) -> + try + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + let b256_th = snd(find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "bytes256")) (lhs(concl th)) && + not(can (find_term (has_const "MAYCHANGE")) (concl th))) asl) in + let b256_state = find_term (fun t -> + try let n = fst(dest_var t) in + String.length n > 1 && n.[0] = 's' && type_of t = `:x86state` + with _ -> false) (lhs(concl b256_th)) in + let b256_state_name = fst(dest_var b256_state) in + let ymm_th = snd(find (fun (_,th) -> + is_eq(concl th) && type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) && + can (find_term (has_var b256_state_name)) (lhs(concl th))) asl) in + let b256_ymm3 = TRANS b256_th (SYM ymm_th) in + let curlen_bound = snd(find (fun (_,th) -> + try concl th = `curlen <= 248` with _ -> false) asl) in + let vwe32 = CONV_RULE NUM_REDUCE_CONV + (REWRITE_RULE[DIMINDEX_32] (INST_TYPE [`:32`,`:N`] VAL_WORD_EQ)) in + let vwe64 = CONV_RULE NUM_REDUCE_CONV + (REWRITE_RULE[DIMINDEX_64] (INST_TYPE [`:64`,`:N`] VAL_WORD_EQ)) in + let cn32 = MATCH_MP vwe32 (CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 4294967296`) + curlen_bound)) in + let cn64 = MATCH_MP vwe64 (CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 18446744073709551616`) + curlen_bound)) in + let b256_norm = REWRITE_RULE[cn32; cn64] b256_ymm3 in + let val_eq = AP_TERM `val:int256->num` b256_norm in + let bytes32_eq = CONV_RULE(LAND_CONV( + REWRITE_CONV[READ_COMPONENT_COMPOSE; VAL_READ_BYTES256] THENC + REWRITE_CONV[GSYM READ_COMPONENT_COMPOSE])) val_eq in + let vpermd_raw = snd(List.find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "MOD")) (concl th) && + can (find_term (has_const "num_of_wordlist")) (concl th) && + can (find_term (has_var "newlen")) (concl th)) asl) in + let is_ymm3_read th = + is_eq(concl th) && type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) in + let vpermd_states = List.filter (fun v -> + let n = fst(dest_var v) in String.length n > 1 && n.[0] = 's' && + type_of v = `:x86state`) (frees(concl vpermd_raw)) in + let vp_state_name = fst(dest_var(List.hd vpermd_states)) in + let vpermd = try + let ymm_b32 = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var b256_state_name)) + (lhs(concl th))) asl) in + let ymm_vp = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var vp_state_name)) + (lhs(concl th))) asl) in + let ymm_eq = TRANS ymm_b32 (SYM ymm_vp) in + let val_eq = AP_TERM `val:int256->num` ymm_eq in + REWRITE_RULE[GSYM val_eq] vpermd_raw + with _ -> vpermd_raw in + let newlen_bound = snd(List.find (fun (_,th) -> + try is_binary "<=" (concl th) && + (has_var "newlen" (lhand(concl th))) && + dest_small_numeral(rand(concl th)) = 8 + with _ -> false) asl) in + let bridge = MATCH_MP VPERMD_MEMORY_BRIDGE + (CONJ bytes32_eq (CONJ vpermd newlen_bound)) in + ASSUME_TAC bridge (asl,w) + with _ -> failwith "J1-A PRE-ENSURES") THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `N = i + 1` (fun th -> + REWRITE_TAC[th; ARITH_RULE `8 * (i + 1) = 8 * i + 8`; + ARITH_RULE `24 * (i + 1) = 24 * i + 24`]) THENL + [UNDISCH_TAC `~(i + 1 < N)` THEN UNDISCH_TAC `i:num < N` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[SUB_LIST_SPLIT; REJ_SAMPLE_APPEND; LENGTH_APPEND; + NUM_OF_WORDLIST_APPEND] THEN + REWRITE_TAC[ADD_CLAUSES] THEN + ABBREV_TAC + `lr = LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) (inlist:(24 word)list)))` THEN + SUBGOAL_THEN `lr <= 8` ASSUME_TAC THENL + [EXPAND_TAC "lr" THEN REWRITE_TAC[REJ_SAMPLE] THEN + TRANS_TAC LE_TRANS `LENGTH(MAP (\x:24 word. word(val x MOD 2 EXP 23):int32) (SUB_LIST(8*i,8) (inlist:(24 word)list)))` THEN + REWRITE_TAC[LENGTH_FILTER; LENGTH_MAP; LENGTH_SUB_LIST] THEN + ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `248 < curlen + lr` ASSUME_TAC THENL + [EXPAND_TAC "lr" THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + FIRST_X_ASSUM(SUBST_ALL_TAC) THEN + (fun (asl, w) -> + try + let newlen_eq_lr = snd(List.find (fun (_, th) -> + let c = concl th in + is_eq c && + (try fst(dest_var(lhs c)) = "newlen" with _ -> false) && + (try fst(dest_var(rhs c)) = "lr" with _ -> false)) asl) in + RULE_ASSUM_TAC (REWRITE_RULE [newlen_eq_lr]) (asl, w) + with _ -> ALL_TAC (asl, w)) THEN + DISCARD_ASSUMPTIONS_TAC (fun th -> + let c = concl th in + let fvs = frees c in + let has_const name t = try fst(dest_const t) = name with _ -> false in + not(is_eq c && + can (find_term (has_const "read")) c && + can (find_term (has_const "bytes")) c) && + (not (forall (fun v -> type_of v = `:num`) fvs) || + exists (fun v -> let n = fst(dest_var v) in + n = "N" || n = "newlen" || n = "curlist") fvs || + is_forall c)) THEN + REPEAT CONJ_TAC THENL + [MATCH_MP_TAC(TAUT `p ==> (if p then a else b) = a`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ + lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ + curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `248 <= curlen + lr` ASSUME_TAC THENL + [UNDISCH_TAC `248 < curlen + lr` THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(curlen + lr) - 248 < 18446744073709551616` + ASSUME_TAC THENL + [UNDISCH_TAC `curlen + lr < 18446744073709551616` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT; MOD_MOD_REFL] THEN + UNDISCH_TAC `248 < curlen + lr` THEN ARITH_TAC; + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ + lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ + curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC; + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + UNDISCH_TAC `24 * i <= 808` THEN + SPEC_TAC(`24 * i`,`n:num`) THEN GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `n < 4294967296 /\ n < 18446744073709551616 /\ + n + 24 < 4294967296 /\ + n + 24 < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `n <= 808` THEN ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC; + REWRITE_TAC[DIMINDEX_32; + ARITH_RULE `4 * (curlen + lr) = 4 * curlen + 4 * lr`; + ARITH_RULE `32 * curlen = 8 * (4 * curlen)`] THEN + REWRITE_TAC[MEMORY_BYTES_SPLIT] THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[EQ_ADD_LCANCEL; EQ_MULT_LCANCEL; EXP_EQ_0; ARITH_EQ] THEN + ASM_REWRITE_TAC[] THEN NO_TAC]; + + (* J1-B: JAE at s37 falls through to pc+111. VSTEPS 38-39 do CMP ecx/JA + and exit to pc+181 because 808 < 24*(i+1) (from offset disjunct). *) + ABBREV_TAC + `lr = LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) (inlist:(24 word)list)))` THEN + SUBGOAL_THEN `lr <= 8` ASSUME_TAC THENL + [EXPAND_TAC "lr" THEN REWRITE_TAC[REJ_SAMPLE] THEN + TRANS_TAC LE_TRANS `LENGTH(MAP (\x:24 word. word(val x MOD 2 EXP 23):int32) (SUB_LIST(8*i,8) (inlist:(24 word)list)))` THEN + REWRITE_TAC[LENGTH_FILTER; LENGTH_MAP; LENGTH_SUB_LIST] THEN + ARITH_TAC; + ALL_TAC] THEN + (* Resolve RIP s37 = word(pc + 111) via COND simplification *) + SUBGOAL_THEN `read RIP s37 = word(pc + 111):int64` MP_TAC THENL + [FIRST_X_ASSUM(fun th -> + if can (find_term ((=) `read RIP s37`)) (concl th) && + is_eq(concl th) + then SUBST1_TAC th else failwith "") THEN + MATCH_MP_TAC(TAUT `~p ==> (if p then a else b) = b`) THEN + REWRITE_TAC[DE_MORGAN_THM; NOT_CLAUSES; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ + lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ + curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT; MOD_MOD_REFL] THEN + UNDISCH_TAC `~(248 < curlen + lr)` THEN + ARITH_TAC; + ALL_TAC] THEN + DISCH_THEN ASSUME_TAC THEN + FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + let c = concl th in + can (find_term ((=) `read RIP s37`)) c && is_eq c && + can (find_term (fun t -> + try fst(dest_const t) = "COND" with _ -> false)) (rhs c))) THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC (38--39) THEN + (* Resolve RIP s39 = word(pc + 181) via JA analysis *) + SUBGOAL_THEN `read RIP s39 = word(pc + 181):int64` MP_TAC THENL + [FIRST_X_ASSUM(fun th -> + if can (find_term ((=) `read RIP s39`)) (concl th) && + is_eq(concl th) + then SUBST1_TAC th else failwith "") THEN + MATCH_MP_TAC(TAUT `p ==> (if p then a else b) = a`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `24 * i < 4294967296 /\ 24 * i < 18446744073709551616 /\ + 24 * i + 24 < 4294967296 /\ + 24 * i + 24 < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `24 * i <= 808` THEN ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT; MOD_MOD_REFL] THEN + UNDISCH_TAC `832 < 24 * (N + 1)` THEN + UNDISCH_TAC `N = i + 1` THEN ARITH_TAC; + ALL_TAC] THEN + DISCH_THEN ASSUME_TAC THEN + FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + let c = concl th in + can (find_term ((=) `read RIP s39`)) c && is_eq c && + can (find_term (fun t -> + try fst(dest_const t) = "COND" with _ -> false)) (rhs c))) THEN + (* Rest mirrors J2's body, adapted for s39 *) + SUBGOAL_THEN `newlen <= 8` ASSUME_TAC THENL + [MP_TAC(ISPECL [`inlist:(24 word)list`; `i:num`; `curlist:int32 list`; + `curlen:num`] SIMD_ITERATION_BRIDGE) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[] THEN + UNDISCH_TAC `24 * (i + 1) <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + STRIP_TAC THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + (fun (asl,w) -> + try + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + let b256_th = snd(find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "bytes256")) (lhs(concl th)) && + not(can (find_term (has_const "MAYCHANGE")) (concl th))) asl) in + let b256_state = find_term (fun t -> + try let n = fst(dest_var t) in + String.length n > 1 && n.[0] = 's' && type_of t = `:x86state` + with _ -> false) (lhs(concl b256_th)) in + let b256_state_name = fst(dest_var b256_state) in + let ymm_th = snd(find (fun (_,th) -> + is_eq(concl th) && type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) && + can (find_term (has_var b256_state_name)) (lhs(concl th))) asl) in + let b256_ymm3 = TRANS b256_th (SYM ymm_th) in + let curlen_bound = snd(find (fun (_,th) -> + try concl th = `curlen <= 248` with _ -> false) asl) in + let vwe32 = CONV_RULE NUM_REDUCE_CONV + (REWRITE_RULE[DIMINDEX_32] (INST_TYPE [`:32`,`:N`] VAL_WORD_EQ)) in + let vwe64 = CONV_RULE NUM_REDUCE_CONV + (REWRITE_RULE[DIMINDEX_64] (INST_TYPE [`:64`,`:N`] VAL_WORD_EQ)) in + let cn32 = MATCH_MP vwe32 (CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 4294967296`) + curlen_bound)) in + let cn64 = MATCH_MP vwe64 (CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 18446744073709551616`) + curlen_bound)) in + let b256_norm = REWRITE_RULE[cn32; cn64] b256_ymm3 in + let val_eq = AP_TERM `val:int256->num` b256_norm in + let bytes32_eq = CONV_RULE(LAND_CONV( + REWRITE_CONV[READ_COMPONENT_COMPOSE; VAL_READ_BYTES256] THENC + REWRITE_CONV[GSYM READ_COMPONENT_COMPOSE])) val_eq in + let vpermd_raw = snd(List.find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "MOD")) (concl th) && + can (find_term (has_const "num_of_wordlist")) (concl th) && + can (find_term (has_var "newlen")) (concl th)) asl) in + let is_ymm3_read th = + is_eq(concl th) && type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) in + let vpermd_states = List.filter (fun v -> + let n = fst(dest_var v) in String.length n > 1 && n.[0] = 's' && + type_of v = `:x86state`) (frees(concl vpermd_raw)) in + let vp_state_name = fst(dest_var(List.hd vpermd_states)) in + let vpermd = try + let ymm_b32 = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var b256_state_name)) + (lhs(concl th))) asl) in + let ymm_vp = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var vp_state_name)) + (lhs(concl th))) asl) in + let ymm_eq = TRANS ymm_b32 (SYM ymm_vp) in + let val_eq = AP_TERM `val:int256->num` ymm_eq in + REWRITE_RULE[GSYM val_eq] vpermd_raw + with _ -> vpermd_raw in + let newlen_bound = snd(List.find (fun (_,th) -> + try is_binary "<=" (concl th) && + (has_var "newlen" (lhand(concl th))) && + dest_small_numeral(rand(concl th)) = 8 + with _ -> false) asl) in + let bridge = MATCH_MP VPERMD_MEMORY_BRIDGE + (CONJ bytes32_eq (CONJ vpermd newlen_bound)) in + ASSUME_TAC bridge (asl,w) + with _ -> failwith "J1-B PRE-ENSURES") THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `N = i + 1` (fun th -> + REWRITE_TAC[th; ARITH_RULE `8 * (i + 1) = 8 * i + 8`; + ARITH_RULE `24 * (i + 1) = 24 * i + 24`]) THENL + [UNDISCH_TAC `~(i + 1 < N)` THEN UNDISCH_TAC `i:num < N` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[SUB_LIST_SPLIT; REJ_SAMPLE_APPEND; LENGTH_APPEND; + NUM_OF_WORDLIST_APPEND] THEN + REWRITE_TAC[ADD_CLAUSES] THEN + (* lr already abbreviated in J1-B prologue *) + ASM_REWRITE_TAC[] THEN + REPEAT CONJ_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ + lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ + curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC; + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + UNDISCH_TAC `24 * i <= 808` THEN + SPEC_TAC(`24 * i`,`n:num`) THEN GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `n < 4294967296 /\ n < 18446744073709551616 /\ + n + 24 < 4294967296 /\ + n + 24 < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `n <= 808` THEN ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC; + REWRITE_TAC[DIMINDEX_32; + ARITH_RULE `4 * (curlen + lr) = 4 * curlen + 4 * lr`; + ARITH_RULE `32 * curlen = 8 * (4 * curlen)`] THEN + REWRITE_TAC[MEMORY_BYTES_SPLIT] THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[EQ_ADD_LCANCEL; EQ_MULT_LCANCEL; EXP_EQ_0; ARITH_EQ] THEN + (fun (asl, w) -> + try + let newlen_eq_lr = snd(List.find (fun (_, th) -> + let c = concl th in + is_eq c && + (try fst(dest_var(lhs c)) = "newlen" with _ -> false) && + (try fst(dest_var(rhs c)) = "lr" with _ -> false)) asl) in + RULE_ASSUM_TAC (REWRITE_RULE [newlen_eq_lr]) (asl, w) + with _ -> ALL_TAC (asl, w)) THEN + ASM_REWRITE_TAC[] THEN NO_TAC]]; + (* J2: 248 < LENGTH(REJ_SAMPLE(SUB_LIST(0,8*N))): count exit *) + (* Prelude: establish newlen <= 8, needed by PRE-ENSURES *) + SUBGOAL_THEN `newlen <= 8` ASSUME_TAC THENL + [MP_TAC(ISPECL [`inlist:(24 word)list`; `i:num`; `curlist:int32 list`; + `curlen:num`] SIMD_ITERATION_BRIDGE) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[] THEN + UNDISCH_TAC `24 * (i + 1) <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + STRIP_TAC THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + (* PRE-ENSURES: derive full VPERMD bridge result before ENSURES_FINAL_STATE_TAC. + Build: read(bytes(word_add res (word(4*curlen)), 4*newlen)) sN = + num_of_wordlist(REJ_SAMPLE(SUB_LIST(8*i,8) inlist)) + so that ASM_REWRITE_TAC can use it after ENSURES_FINAL_STATE_TAC *) + (fun (asl,w) -> + try + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + (* 1. Derive bytes32 eq: read(bytes(addr,32)) sN = val(YMM3 sN) *) + let b256_th = snd(find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "bytes256")) (lhs(concl th)) && + not(can (find_term (has_const "MAYCHANGE")) (concl th))) asl) in + let b256_state = find_term (fun t -> + try let n = fst(dest_var t) in + String.length n > 1 && n.[0] = 's' && type_of t = `:x86state` + with _ -> false) (lhs(concl b256_th)) in + let b256_state_name = fst(dest_var b256_state) in + let ymm_th = snd(find (fun (_,th) -> + is_eq(concl th) && type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) && + can (find_term (has_var b256_state_name)) (lhs(concl th))) asl) in + let b256_ymm3 = TRANS b256_th (SYM ymm_th) in + let curlen_bound = snd(find (fun (_,th) -> + try concl th = `curlen <= 248` with _ -> false) asl) in + let vwe32 = CONV_RULE NUM_REDUCE_CONV + (REWRITE_RULE[DIMINDEX_32] (INST_TYPE [`:32`,`:N`] VAL_WORD_EQ)) in + let vwe64 = CONV_RULE NUM_REDUCE_CONV + (REWRITE_RULE[DIMINDEX_64] (INST_TYPE [`:64`,`:N`] VAL_WORD_EQ)) in + let cn32 = MATCH_MP vwe32 (CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 4294967296`) curlen_bound)) in + let cn64 = MATCH_MP vwe64 (CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 18446744073709551616`) curlen_bound)) in + let b256_norm = REWRITE_RULE[cn32; cn64] b256_ymm3 in + let val_eq = AP_TERM `val:int256->num` b256_norm in + let bytes32_eq = CONV_RULE(LAND_CONV( + REWRITE_CONV[READ_COMPONENT_COMPOSE; VAL_READ_BYTES256] THENC + REWRITE_CONV[GSYM READ_COMPONENT_COMPOSE])) val_eq in + (* 2. Get VPERMD MOD hyp and bridge states *) + let vpermd_raw = snd(List.find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "MOD")) (concl th) && + can (find_term (has_const "num_of_wordlist")) (concl th) && + can (find_term (has_var "newlen")) (concl th)) asl) in + let is_ymm3_read th = + is_eq(concl th) && type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) in + let vpermd_states = List.filter (fun v -> + let n = fst(dest_var v) in String.length n > 1 && n.[0] = 's' && + type_of v = `:x86state`) (frees(concl vpermd_raw)) in + let vp_state_name = fst(dest_var(List.hd vpermd_states)) in + let vpermd = try + let ymm_b32 = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var b256_state_name)) (lhs(concl th))) asl) in + let ymm_vp = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var vp_state_name)) (lhs(concl th))) asl) in + let ymm_eq = TRANS ymm_b32 (SYM ymm_vp) in + REWRITE_RULE[GSYM(AP_TERM `val:int256->num` ymm_eq)] vpermd_raw + with _ -> vpermd_raw in + (* 3. Get newlen <= 8 *) + let newlen_bound = snd(List.find (fun (_,th) -> + try is_binary "<=" (concl th) && + (has_var "newlen" (lhand(concl th))) && + dest_small_numeral(rand(concl th)) = 8 + with _ -> false) asl) in + (* 4. Forward chain VPERMD_MEMORY_BRIDGE *) + let bridge = MATCH_MP VPERMD_MEMORY_BRIDGE + (CONJ bytes32_eq (CONJ vpermd newlen_bound)) in + ASSUME_TAC bridge (asl,w) + with _ -> ALL_TAC (asl,w)) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + (* Substitute N = i+1 *) + SUBGOAL_THEN `N = i + 1` (fun th -> + REWRITE_TAC[th; ARITH_RULE `8 * (i + 1) = 8 * i + 8`; + ARITH_RULE `24 * (i + 1) = 24 * i + 24`]) THENL + [UNDISCH_TAC `~(i + 1 < N)` THEN UNDISCH_TAC `i:num < N` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[SUB_LIST_SPLIT; REJ_SAMPLE_APPEND; LENGTH_APPEND; + NUM_OF_WORDLIST_APPEND] THEN + REWRITE_TAC[ADD_CLAUSES] THEN + ABBREV_TAC `lr = LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) (inlist:(24 word)list)))` THEN + SUBGOAL_THEN `lr <= 8` ASSUME_TAC THENL + [EXPAND_TAC "lr" THEN REWRITE_TAC[REJ_SAMPLE] THEN + TRANS_TAC LE_TRANS `LENGTH(MAP (\x:24 word. word(val x MOD 2 EXP 23):32 word) (SUB_LIST(8*i,8) (inlist:(24 word)list)))` THEN + REWRITE_TAC[LENGTH_FILTER; LENGTH_MAP; LENGTH_SUB_LIST] THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `248 < curlen + lr` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + can (find_term (fun t -> try fst(dest_const t) = "REJ_SAMPLE" with _ -> false)) (concl th) && + can (find_term (fun t -> try dest_small_numeral t > 200 with _ -> false)) (concl th))) THEN + SUBGOAL_THEN `N = i + 1` (fun th -> REWRITE_TAC[th; ARITH_RULE `8 * (i + 1) = 8 * i + 8`]) THENL + [UNDISCH_TAC `~(i + 1 < N)` THEN UNDISCH_TAC `i:num < N` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[SUB_LIST_SPLIT; REJ_SAMPLE_APPEND; LENGTH_APPEND] THEN + REWRITE_TAC[ADD_CLAUSES] THEN EXPAND_TAC "lr" THEN ARITH_TAC; + ALL_TAC] THEN + FIRST_X_ASSUM(SUBST_ALL_TAC) THEN + (* Rewrite bridge with newlen = lr BEFORE DISCARD eats the connection *) + (fun (asl, w) -> + try + let newlen_eq_lr = snd(List.find (fun (_, th) -> + let c = concl th in + is_eq c && + (try fst(dest_var(lhs c)) = "newlen" with _ -> false) && + (try fst(dest_var(rhs c)) = "lr" with _ -> false)) asl) in + RULE_ASSUM_TAC (REWRITE_RULE [newlen_eq_lr]) (asl, w) + with _ -> ALL_TAC (asl, w)) THEN + DISCARD_ASSUMPTIONS_TAC (fun th -> + let c = concl th in + let fvs = frees c in + let has_const name t = try fst(dest_const t) = name with _ -> false in + (* Keep: bridge (REJ_SAMPLE/read/bytes) AND invariant (read/bytes, curlist RHS) *) + not(is_eq c && + can (find_term (has_const "read")) c && + can (find_term (has_const "bytes")) c) && + (not (forall (fun v -> type_of v = `:num`) fvs) || + exists (fun v -> let n = fst(dest_var v) in + n = "N" || n = "newlen" || n = "curlist") fvs || + is_forall c)) THEN + REPEAT CONJ_TAC THENL + [(* 1. JA conditional. + Use targeted UNDISCH instead of ASM_ARITH_TAC to avoid hanging + on the ~200-assumption list. *) + MATCH_MP_TAC(TAUT `p ==> (if p then a else b) = a`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `248 <= curlen + lr` ASSUME_TAC THENL + [UNDISCH_TAC `248 < curlen + lr` THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(curlen + lr) - 248 < 18446744073709551616` ASSUME_TAC THENL + [UNDISCH_TAC `curlen + lr < 18446744073709551616` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT; MOD_MOD_REFL] THEN + UNDISCH_TAC `248 < curlen + lr` THEN ARITH_TAC; + (* 2. RAX *) + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC; + (* 3. RCX *) + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + UNDISCH_TAC `24 * i <= 808` THEN + SPEC_TAC(`24 * i`,`n:num`) THEN GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `n < 4294967296 /\ n < 18446744073709551616 /\ + n + 24 < 4294967296 /\ n + 24 < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `n <= 808` THEN ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC; + (* 4. Memory store — bridge theorem from PRE-ENSURES is in assumptions *) + REWRITE_TAC[DIMINDEX_32; + ARITH_RULE `4 * (curlen + lr) = 4 * curlen + 4 * lr`; + ARITH_RULE `32 * curlen = 8 * (4 * curlen)`] THEN + REWRITE_TAC[MEMORY_BYTES_SPLIT] THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[EQ_ADD_LCANCEL; EQ_MULT_LCANCEL; EXP_EQ_0; ARITH_EQ] THEN + ASM_REWRITE_TAC[] THEN NO_TAC]]]; + + (* ================================================================= *) + (* Subgoal 3: Post-loop *) + (* ================================================================= *) + (* ================================================================= *) + (* Subgoal 3: Post-loop (scalar loop + VZEROUPPER + RET) *) + (* *) + (* Entry: pc+181 with REJ_SAMPLE(SUB_LIST(0,8*N)) accumulated. *) + (* Code structure: *) + (* pc+181: CMP eax,256; JAE vzeroupper *) + (* pc+188: CMP ecx,837; JA vzeroupper *) + (* pc+196..240: scalar coefficient loop (≤ 8 iterations) *) + (* pc+242: VZEROUPPER *) + (* *) + (* Preparation: abbreviate outlist/outlen, establish bounds. *) + (* ================================================================= *) + CONV_TAC(RATOR_CONV(LAND_CONV(TOP_DEPTH_CONV let_CONV))) THEN + MAP_EVERY ABBREV_TAC + [`outlist = REJ_SAMPLE (SUB_LIST (0,8 * N) inlist)`; + `outlen = LENGTH(outlist:int32 list)`] THEN + (* Save LENGTH(REJ_SAMPLE(...)) = outlen before ABBREV consumes the connection *) + SUBGOAL_THEN + `LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N) (inlist:(24 word)list))) = outlen` + ASSUME_TAC THENL + [UNDISCH_TAC `REJ_SAMPLE (SUB_LIST (0,8 * N) (inlist:(24 word)list)) = outlist` THEN + DISCH_THEN(fun th -> REWRITE_TAC[th]) THEN + UNDISCH_TAC `LENGTH (outlist:int32 list) = outlen` THEN + DISCH_THEN(fun th -> REWRITE_TAC[th]); + ALL_TAC] THEN + (* Derive 24*N <= 832 and LENGTH(REJ_SAMPLE(SUB_LIST(0, 8*(N-1)))) <= 248 *) + SUBGOAL_THEN + `24 * N <= 832 /\ + LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * (N - 1)) (inlist:(24 word)list))) <= 248` + STRIP_ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o SPEC `N - 1`) THEN + ANTS_TAC THENL [UNDISCH_TAC `~(N = 0)` THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(N - 1) + 1 = N` SUBST1_TAC THENL + [UNDISCH_TAC `~(N = 0)` THEN ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[]; + ALL_TAC] THEN + (* Derive outlen <= 256 via SIMD_ITERATION_BRIDGE at (N-1) *) + SUBGOAL_THEN `outlen <= 256` ASSUME_TAC THENL + [MP_TAC(ISPECL [`inlist:(24 word)list`; `N - 1`; + `REJ_SAMPLE(SUB_LIST(0, 8*(N-1)) (inlist:(24 word)list))`; + `LENGTH(REJ_SAMPLE(SUB_LIST(0, 8*(N-1)) (inlist:(24 word)list)))`] + SIMD_ITERATION_BRIDGE) THEN + ANTS_TAC THENL + [REWRITE_TAC[] THEN + SUBGOAL_THEN `N - 1 + 1 = N` SUBST1_TAC THENL + [UNDISCH_TAC `~(N = 0)` THEN ARITH_TAC; ALL_TAC] THEN + UNDISCH_TAC `LENGTH (inlist:(24 word)list) = 280` THEN + UNDISCH_TAC `24 * N <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `N - 1 + 1 = N` SUBST1_TAC THENL + [UNDISCH_TAC `~(N = 0)` THEN ARITH_TAC; ALL_TAC] THEN + STRIP_TAC THEN + UNDISCH_TAC + `LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N) (inlist:(24 word)list))) = outlen` THEN + UNDISCH_TAC + `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N) (inlist:(24 word)list))) = + LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * (N - 1)) inlist)) + + LENGTH (REJ_SAMPLE (SUB_LIST (8 * (N - 1),8) inlist))` THEN + UNDISCH_TAC + `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * (N - 1)) (inlist:(24 word)list))) <= 248` THEN + UNDISCH_TAC + `LENGTH (REJ_SAMPLE (SUB_LIST (8 * (N - 1),8) (inlist:(24 word)list))) <= 8` THEN + ARITH_TAC; + ALL_TAC] THEN + (* Characterize the number of scalar iterations K via WOP. + K = smallest j such that LENGTH(REJ_SAMPLE(SUB_LIST(0,8*N+j))) >= 256 OR 837 < 24*N+3*j. *) + SUBGOAL_THEN + `?j. 256 <= LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N + j) (inlist:(24 word)list))) \/ + 837 < 24 * N + 3 * j` + MP_TAC THENL + [EXISTS_TAC `280:num` THEN DISJ2_TAC THEN + UNDISCH_TAC `24 * N <= 832` THEN ARITH_TAC; + GEN_REWRITE_TAC LAND_CONV [num_WOP]] THEN + DISCH_THEN(X_CHOOSE_THEN `K:num` (CONJUNCTS_THEN2 ASSUME_TAC MP_TAC)) THEN + DISCH_THEN(fun th -> + ASSUME_TAC(REWRITE_RULE[DE_MORGAN_THM; NOT_LE; NOT_LT] th)) THEN + (* Case split: K = 0 (no scalar iterations) vs K > 0 (scalar loop) *) + ASM_CASES_TAC `K = 0` THENL + [ (* K = 0: Must have outlen = 256 (since 24*N <= 832 rules out offset exit at K=0). *) + SUBGOAL_THEN `outlen = 256` ASSUME_TAC THENL + [RULE_ASSUM_TAC(REWRITE_RULE[ASSUME `K = 0`; ADD_CLAUSES; MULT_CLAUSES]) THEN + UNDISCH_TAC + `256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N) (inlist:(24 word)list))) \/ + 837 < 24 * N` THEN + UNDISCH_TAC + `LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N) (inlist:(24 word)list))) = outlen` THEN + DISCH_THEN(fun th -> REWRITE_TAC[th]) THEN + UNDISCH_TAC `outlen <= 256` THEN + UNDISCH_TAC `24 * N <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + (* Apply case A proof: JAE fires to pc+242 *) + ENSURES_INIT_TAC "s0" THEN + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [40;41] THEN + RULE_ASSUM_TAC(REWRITE_RULE[ASSUME `outlen = 256`]) THEN + FIRST_X_ASSUM(fun th -> + if (try let s = string_of_term (concl th) in String.length s > 20 && + String.sub s 0 11 = "read RIP s4" with _ -> false) && + is_eq(concl th) + then ASSUME_TAC(CONV_RULE(RAND_CONV(DEPTH_CONV WORD_NUM_RED_CONV)) th) + else failwith "not RIP") THEN + (* vzeroupper removed (was step 55); RIP is already at the RET. *) + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `SUB_LIST (0,256) (REJ_SAMPLE (inlist:(24 word)list)) = + REJ_SAMPLE (SUB_LIST (0, 8 * N) inlist)` + ASSUME_TAC THENL + [MATCH_MP_TAC REJ_SAMPLE_PREFIX_256 THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + UNDISCH_TAC + `REJ_SAMPLE (SUB_LIST (0,8 * N) (inlist:(24 word)list)) = outlist` THEN + DISCH_THEN(fun th -> REWRITE_TAC[th]) THEN + UNDISCH_TAC `LENGTH (outlist:int32 list) = outlen` THEN + DISCH_THEN(fun th -> REWRITE_TAC[th]) THEN + ASM_REWRITE_TAC[]; + (* K > 0: scalar loop runs K times. Use ENSURES_WHILE_UP2_TAC. *) + ENSURES_WHILE_UP2_TAC `K:num` `pc + 181` `pc + 242` + `\i s. read RSP s = stackpointer /\ + read (memory :> bytes (buf,840)) s = num_of_wordlist inlist /\ + read (memory :> bytes (table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list) /\ + read RDI s = res /\ read RSI s = buf /\ read RDX s = table /\ + read YMM0 s = + word 115366376096492355175489748997433888275274855593258845241081954797768348401920 /\ + read YMM1 s = + word 226156397384342666605459106258636701594091082888230722833791023177481060351 /\ + read YMM2 s = + word 225935595421087293402315996791205668696012104344015382954355885915737415681 /\ + (let outlist_i = REJ_SAMPLE(SUB_LIST(0, 8 * N + i) (inlist:(24 word)list)) in + let outlen_i = LENGTH outlist_i in + read RAX s = word outlen_i /\ + read RCX s = word(24 * N + 3 * i) /\ + read(memory :> bytes(res, 4 * outlen_i)) s = num_of_wordlist outlist_i)` THEN + ASM_REWRITE_TAC[] THEN REPEAT CONJ_TAC THENL + [ (* Init: precond -> invariant @ 0 *) + ENSURES_INIT_TAC "s0" THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[ADD_CLAUSES; MULT_CLAUSES] THEN ASM_REWRITE_TAC[]; + (* Body: invariant @ i -> invariant @ (i+1) at pc+181 or pc+242. + Discharge via SCALAR_BODY_LEMMA, which packages the full iteration: + CMP eax,256; JAE (not taken), CMP ecx,837; JA (not taken), + MOVZX + OR + AND + ADD + CMP Q + JAE (either back or continue), + MOV + ADD + JMP back. + The wrapper specializes N,K,i,inlist,res,buf,table,pc,stackpointer. + The (forall j < K. ...) precondition is discharged from the WOP + assumption `!m. m < K ==> ~(256 <= LENGTH(...)) /\ ~(837 < 24*N+3*m)`. *) + X_GEN_TAC `i:num` THEN STRIP_TAC THEN + REWRITE_TAC[GSYM SOME_FLAGS] THEN + MATCH_MP_TAC SCALAR_BODY_LEMMA THEN + ASM_REWRITE_TAC[NONOVERLAPPING_CLAUSES] THEN + CONJ_TAC THENL + [X_GEN_TAC `j:num` THEN DISCH_TAC THEN + FIRST_X_ASSUM(MP_TAC o SPEC `j:num` o check (is_forall o concl)) THEN + ASM_REWRITE_TAC[] THEN ARITH_TAC; + (* WOP disjunct at K *) + FIRST_X_ASSUM(MATCH_ACCEPT_TAC o check (fun th -> + let c = concl th in is_disj c && + can (find_term ((=) `256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))`)) c))]; + (* Post: invariant @ K -> postcondition. + At i=K, exit condition fires. RIP = pc+242 (vzeroupper). *) + ENSURES_INIT_TAC "s0" THEN + (* Break out the invariant conjunction *) + RULE_ASSUM_TAC(CONV_RULE(TOP_DEPTH_CONV let_CONV)) THEN + FIRST_X_ASSUM(fun th -> + let c = concl th in + if is_conj c && (try can (find_term ((=) `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))`)) c with _ -> false) + then STRIP_ASSUME_TAC th else failwith "not inv") THEN + (* vzeroupper removed (was step 55); RIP is already at the RET. *) + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + (* The disjunct at K: either count-exit (256 <= outlen_K) or offset-exit (837 < 24*N+3*K) *) + FIRST_X_ASSUM(DISJ_CASES_TAC o check (is_disj o concl)) THENL + [ (* count-exit case: 256 <= outlen_K. Since outlen is monotonic +0/+1 per scalar iter, + and outlen_{K-1} < 256 (from WOP), we have outlen_K = 256. *) + SUBGOAL_THEN + `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list))) = 256` + ASSUME_TAC THENL + [(* Monotonicity: LENGTH(REJ_SAMPLE(SUB_LIST(0, 8*N + K-1))) < 256 (from WOP) + and REJ_SAMPLE_STEP_LE gives LENGTH(...K) <= LENGTH(...K-1) + 1 <= 256. + Combined with 256 <= LENGTH(...K) gives equality. *) + FIRST_X_ASSUM(MP_TAC o SPEC `K - 1`) THEN + ANTS_TAC THENL [UNDISCH_TAC `~(K = 0)` THEN ARITH_TAC; ALL_TAC] THEN + MP_TAC(ISPECL [`inlist:(24 word)list`; `8 * N + K - 1`] REJ_SAMPLE_STEP_LE) THEN + SUBGOAL_THEN `(8 * N + K - 1) + 1 = 8 * N + K` SUBST1_TAC THENL + [UNDISCH_TAC `~(K = 0)` THEN ARITH_TAC; ALL_TAC] THEN + UNDISCH_TAC + `256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))` THEN + ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `SUB_LIST (0,256) (REJ_SAMPLE (inlist:(24 word)list)) = + REJ_SAMPLE (SUB_LIST (0, 8 * N + K) inlist)` + ASSUME_TAC THENL + [MATCH_MP_TAC REJ_SAMPLE_PREFIX_256 THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `SUB_LIST (0,256) (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list))) = + REJ_SAMPLE (SUB_LIST (0,8 * N + K) inlist)` + SUBST1_TAC THENL + [MATCH_MP_TAC SUB_LIST_REFL THEN ASM_REWRITE_TAC[LE_REFL]; + ALL_TAC] THEN + (* Rewrite memory hyp using LENGTH = 256 *) + RULE_ASSUM_TAC(REWRITE_RULE[ASSUME + `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list))) = 256`]) THEN + ASM_REWRITE_TAC[]; + (* offset-exit case: 837 < 24*N+3*K. Need to handle whether count-exit also fires. *) + ASM_CASES_TAC + `256 <= LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N + K) (inlist:(24 word)list)))` + THENL + [ (* Both conditions: 256 <= outlen_K. Derive outlen_K = 256 via monotonicity, + then reduce to case A. *) + SUBGOAL_THEN + `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list))) = 256` + ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o SPEC `K - 1`) THEN + ANTS_TAC THENL [UNDISCH_TAC `~(K = 0)` THEN ARITH_TAC; ALL_TAC] THEN + MP_TAC(ISPECL [`inlist:(24 word)list`; `8 * N + K - 1`] REJ_SAMPLE_STEP_LE) THEN + SUBGOAL_THEN `(8 * N + K - 1) + 1 = 8 * N + K` SUBST1_TAC THENL + [UNDISCH_TAC `~(K = 0)` THEN ARITH_TAC; ALL_TAC] THEN + UNDISCH_TAC + `256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))` THEN + ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `SUB_LIST (0,256) (REJ_SAMPLE (inlist:(24 word)list)) = + REJ_SAMPLE (SUB_LIST (0, 8 * N + K) inlist)` + ASSUME_TAC THENL + [MATCH_MP_TAC REJ_SAMPLE_PREFIX_256 THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `SUB_LIST (0,256) (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list))) = + REJ_SAMPLE (SUB_LIST (0,8 * N + K) inlist)` + SUBST1_TAC THENL + [MATCH_MP_TAC SUB_LIST_REFL THEN ASM_REWRITE_TAC[LE_REFL]; + ALL_TAC] THEN + RULE_ASSUM_TAC(REWRITE_RULE[ASSUME + `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list))) = 256`]) THEN + ASM_REWRITE_TAC[]; + (* Only offset-exit: outlen_K < 256 and 24*N+3*K > 837. + Then 8*N+K >= 280 (bytes consumed past input), so SUB_LIST = inlist. *) + SUBGOAL_THEN `SUB_LIST (0, 8 * N + K) (inlist:(24 word)list) = inlist` + SUBST1_TAC THENL + [MATCH_MP_TAC SUB_LIST_REFL THEN + UNDISCH_TAC `LENGTH (inlist:(24 word)list) = 280` THEN + UNDISCH_TAC `837 < 24 * N + 3 * K` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `LENGTH (REJ_SAMPLE (inlist:(24 word)list)) <= 256` + ASSUME_TAC THENL + [UNDISCH_TAC + `~(256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list))))` THEN + SUBGOAL_THEN `SUB_LIST (0, 8 * N + K) (inlist:(24 word)list) = inlist` + SUBST1_TAC THENL + [MATCH_MP_TAC SUB_LIST_REFL THEN + UNDISCH_TAC `LENGTH (inlist:(24 word)list) = 280` THEN + UNDISCH_TAC `837 < 24 * N + 3 * K` THEN ARITH_TAC; + ALL_TAC] THEN + ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `SUB_LIST (0,256) (REJ_SAMPLE (inlist:(24 word)list)) = REJ_SAMPLE inlist` + SUBST1_TAC THENL + [MATCH_MP_TAC SUB_LIST_REFL THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + REWRITE_TAC[] THEN + (* Memory closure: rewrite SUB_LIST = inlist in the memory hypothesis and accept. + We have to build the SUB_LIST_REFL fact without `prove` (which starts a fresh + proof without access to current asl hypotheses). *) + (fun (asl, w) -> + try + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + let mem_hyp = snd(List.find (fun (_, th) -> + let c = concl th in + is_eq c && + can (find_term (has_const "REJ_SAMPLE")) c && + can (find_term (has_const "bytes")) c && + can (find_term (has_const "memory")) c && + can (find_term (has_var "res")) c) asl) in + let len280 = snd(List.find (fun (_, th) -> + concl th = `LENGTH (inlist:(24 word)list) = 280`) asl) in + let off837 = snd(List.find (fun (_, th) -> + concl th = `837 < 24 * N + 3 * K`) asl) in + let bound_th = MP (MP + (ARITH_RULE `LENGTH (inlist:(24 word)list) = 280 + ==> 837 < 24 * N + 3 * K + ==> LENGTH inlist <= 8 * N + K`) len280) off837 in + let sub_eq = MATCH_MP + (ISPECL [`inlist:(24 word)list`; `8 * N + K`] SUB_LIST_REFL) + bound_th in + let mem_hyp' = REWRITE_RULE[sub_eq] mem_hyp in + ACCEPT_TAC mem_hyp' (asl, w) + with _ -> failwith "memory finalize failed")]]]]]);; + +(* ========================================================================= *) +(* SUBROUTINE_CORRECT variants (standard x86_64 ABI). *) +(* ========================================================================= *) + +let MLDSA_REJ_UNIFORM_NOIBT_SUBROUTINE_CORRECT = prove + (`!res buf table (inlist:(24 word)list) pc stackpointer returnaddress. + LENGTH inlist = 280 /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_tmc) (res, 1024) /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_tmc) (buf, 840) /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_tmc) (table, 2048) /\ + nonoverlapping (res, 1024) (buf, 840) /\ + nonoverlapping (res, 1024) (table, 2048) /\ + nonoverlapping (stackpointer, 8) (res, 1024) /\ + nonoverlapping (stackpointer, 8) (buf, 840) /\ + nonoverlapping (stackpointer, 8) (table, 2048) /\ + nonoverlapping (stackpointer, 8) (word pc, LENGTH mldsa_rej_uniform_tmc) + ==> ensures x86 + (\s. bytes_loaded s (word pc) mldsa_rej_uniform_tmc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + C_ARGUMENTS [res; buf; table] s /\ + read(memory :> bytes(buf,840)) s = num_of_wordlist inlist /\ + read(memory :> bytes(table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list)) + (\s. read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (let outlist = SUB_LIST(0,256) (REJ_SAMPLE inlist) in + let outlen = LENGTH outlist in + C_RETURN s = word outlen /\ + read(memory :> bytes(res,4 * outlen)) s = + num_of_wordlist outlist)) + (MAYCHANGE [RSP] ,, MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(res,1024)])`, + X86_PROMOTE_RETURN_NOSTACK_TAC mldsa_rej_uniform_tmc + MLDSA_REJ_UNIFORM_CORRECT);; + +let MLDSA_REJ_UNIFORM_SUBROUTINE_CORRECT = prove + (`!res buf table (inlist:(24 word)list) pc stackpointer returnaddress. + LENGTH inlist = 280 /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_mc) (res, 1024) /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_mc) (buf, 840) /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_mc) (table, 2048) /\ + nonoverlapping (res, 1024) (buf, 840) /\ + nonoverlapping (res, 1024) (table, 2048) /\ + nonoverlapping (stackpointer, 8) (res, 1024) /\ + nonoverlapping (stackpointer, 8) (buf, 840) /\ + nonoverlapping (stackpointer, 8) (table, 2048) /\ + nonoverlapping (stackpointer, 8) (word pc, LENGTH mldsa_rej_uniform_mc) + ==> ensures x86 + (\s. bytes_loaded s (word pc) mldsa_rej_uniform_mc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + C_ARGUMENTS [res; buf; table] s /\ + read(memory :> bytes(buf,840)) s = num_of_wordlist inlist /\ + read(memory :> bytes(table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list)) + (\s. read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (let outlist = SUB_LIST(0,256) (REJ_SAMPLE inlist) in + let outlen = LENGTH outlist in + C_RETURN s = word outlen /\ + read(memory :> bytes(res,4 * outlen)) s = + num_of_wordlist outlist)) + (MAYCHANGE [RSP] ,, MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(res,1024)])`, + MATCH_ACCEPT_TAC(ADD_IBT_RULE MLDSA_REJ_UNIFORM_NOIBT_SUBROUTINE_CORRECT));; + +(* ========================================================================= *) +(* Memory safety for the non-constant-time rejection sampling. *) +(* *) +(* We prove memory safety alone, _not_ constant-time, for *) +(* mldsa_rej_uniform. This function operates on public data only, hence *) +(* constant-time'ness is not a requirement. Allowing for variable-time *) +(* execution enables a faster implementation. *) +(* *) +(* The standard _SAFE_ proof pattern *) +(* (exists f_events. forall ... e2 = f_events ) cannot *) +(* be used here because the microarchitectural events depend on private *) +(* data (which input bytes pass the < MLDSA_Q filter). *) +(* ========================================================================= *) + +needs "x86/proofs/consttime.ml";; + +(* Helper: discharge the memsafe postcondition + exists e2. read events s = APPEND e2 e /\ memaccess_inbounds e2 R W + after symbolic simulation, using accumulated events from the invariant. + This is DISCHARGE_SAFETY_PROPERTY_TAC minus the f_events unification. *) +let DISCHARGE_MEMSAFE_TAC:tactic = + SAFE_META_EXISTS_TAC allowed_vars_e THEN + CONJ_TAC THENL [ EXISTS_E2_TAC allowed_vars_e; ALL_TAC ] THEN + DISCHARGE_MEMACCESS_INBOUNDS_TAC;; + +(* Like SIMPLE_ARITH_TAC but allows `val` in assumptions since + contained_modulo bounds may involve val terms. Filters out + read/write/word simulation cruft that makes ASM_ARITH_TAC slow. *) +let (MEMSAFE_ARITH_TAC:tactic) = + let numty = `:num` in + let is_num_relop tm = + exists (fun op -> is_binary op tm && + (let x,_ = dest_binary op tm in type_of x = numty)) + ["=";"<";"<=";">";">="] + and avoiders = ["lowdigits"; "highdigits"; "bigdigit"; + "read"; "write"; "word"] in + let avoiderp tm = + match tm with Const(n,_) -> mem n avoiders | _ -> false in + let filtered tm = + (is_num_relop tm || (is_neg tm && is_num_relop (dest_neg tm))) && + not(can (find_term avoiderp) tm) in + let tweak = GEN_REWRITE_RULE TRY_CONV [ARITH_RULE `~(n = 0) <=> 1 <= n`] in + W(fun (asl,w) -> + let asl' = filter (fun (_,th) -> filtered(concl th)) asl in + MAP_EVERY (MP_TAC o tweak o snd) asl' THEN CONV_TAC ARITH_RULE);; + +(* Bring `bitval p <= 1` as a MP_TAC hypothesis so MEMSAFE_ARITH_TAC's + ARITH_RULE can derive bounds on bitval-sum expressions arising from + VPMOVMSKPS-derived table indices. *) +let MEMSAFE_BITVAL_TAC:tactic = + W(fun (asl,w) -> + let bvs = find_terms (fun t -> + try fst(dest_const(rator t)) = "bitval" with _ -> false) w in + let bvs = setify bvs in + MAP_EVERY (fun bv -> + MP_TAC(SPEC (rand bv) BITVAL_BOUND)) bvs);; + +(* ASM-aware version of CONTAINED_TAC for loop-body proofs where + memory addresses involve symbolic loop variables. Uses MEMSAFE_ARITH_TAC + which filters assumptions to avoid the performance issues of ASM_ARITH_TAC + with hundreds of symbolic simulation assumptions. *) +let CONTAINED_ASM_TAC = + GEN_REWRITE_TAC I [GSYM CONTAINED_MODULO_MOD2] THEN + GEN_REWRITE_TAC (BINOP_CONV o LAND_CONV o LAND_CONV o TOP_DEPTH_CONV) + [VAL_WORD_ADD; VAL_WORD; DIMINDEX_64] THEN + CONV_TAC(BINOP_CONV(LAND_CONV MOD_DOWN_CONV)) THEN + REWRITE_TAC[CONTAINED_MODULO_MOD2; CONTAINED_MODULO_LMOD] THEN + ((GEN_REWRITE_TAC I [CONTAINED_MODULO_REFL] THEN + MEMSAFE_BITVAL_TAC THEN MEMSAFE_ARITH_TAC) ORELSE + (MATCH_MP_TAC CONTAINED_MODULO_OFFSET_SIMPLE THEN + MEMSAFE_BITVAL_TAC THEN MEMSAFE_ARITH_TAC) ORELSE + (MATCH_MP_TAC CONTAINED_MODULO_SIMPLE THEN + MEMSAFE_BITVAL_TAC THEN MEMSAFE_ARITH_TAC));; + +(* Variant of DISCARD_OLDSTATE_TAC that preserves hypotheses about + `read events sN` regardless of state references inside their RHS. + Needed because the SIMD loop body's POPCNT operand transitively + references `read (memory :> bytes256 buf) s4`, which would otherwise + cause the whole events chain to be erased. *) +let DISCARD_OLDSTATE_KEEP_EVENTS_TAC (s:string) = + let v = mk_var(s, `:x86state`) in + let rec unbound_statevars_of_read bound_svars tm = + match tm with + Comb(Comb(Const("read",_),cmp),s) -> + if mem s bound_svars then [] else [s] + | Comb(a,b) -> union (unbound_statevars_of_read bound_svars a) + (unbound_statevars_of_read bound_svars b) + | Abs(v,t) -> unbound_statevars_of_read (v::bound_svars) t + | _ -> [] in + let is_events_hyp tm = + is_eq tm && + (try let l = lhs tm in + let f, args = strip_comb l in + fst(dest_const f) = "read" && + List.length args = 2 && + fst(dest_const(List.hd args)) = "events" + with _ -> false) in + DISCARD_ASSUMPTIONS_TAC( + fun thm -> + if is_events_hyp (concl thm) then false + else + let us = unbound_statevars_of_read [] (concl thm) in + if us = [] || us = [v] then false + else if not(mem v us) then true + else true);; + +(* ASM-aware version of DISCHARGE_MEMSAFE_TAC for loop bodies. + Uses CONTAINED_ASM_TAC for contained_modulo proofs with symbolic bounds. *) +let DISCHARGE_MEMSAFE_ASM_TAC:tactic = + SAFE_META_EXISTS_TAC allowed_vars_e THEN + CONJ_TAC THENL [ EXISTS_E2_TAC allowed_vars_e; ALL_TAC ] THEN + REWRITE_TAC[MEMACCESS_INBOUNDS_APPEND] THEN + CONJ_TAC THENL + [REWRITE_TAC[memaccess_inbounds; ALL; EX; FST; SND] THEN + REPEAT CONJ_TAC THEN + TRY(REPEAT ((DISJ1_TAC THEN CONTAINED_ASM_TAC) ORELSE DISJ2_TAC ORELSE + CONTAINED_ASM_TAC) THEN NO_TAC); + REWRITE_TAC[APPEND; APPEND_NIL] THEN + FIRST_ASSUM ACCEPT_TAC];; + +let SCALAR_BODY_LEMMA_MEMSAFE = prove + (`!res buf table (inlist:(24 word)list) e pc stackpointer N K i. + LENGTH inlist = 280 /\ + nonoverlapping (word pc, 243) (res, 1024) /\ + nonoverlapping (word pc, 243) (buf, 840) /\ + nonoverlapping (word pc, 243) (table, 2048) /\ + nonoverlapping (res, 1024) (buf, 840) /\ + nonoverlapping (res, 1024) (table, 2048) /\ + 24 * N <= 832 /\ + LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N) inlist)) <= 256 /\ + i < K /\ ~(i = K) /\ 0 < K /\ + (!j. j < K + ==> LENGTH(REJ_SAMPLE(SUB_LIST(0, 8*N + j) inlist)) < 256 /\ + 24 * N + 3 * j <= 837) /\ + (256 <= LENGTH(REJ_SAMPLE(SUB_LIST(0, 8*N + K) inlist)) \/ + 837 < 24 * N + 3 * K) + ==> + ensures x86 + (\s. bytes_loaded s (word pc) (BUTLAST mldsa_rej_uniform_tmc) /\ + read RIP s = word(pc + 181) /\ + read RSP s = stackpointer /\ + read (memory :> bytes (buf,840)) s = num_of_wordlist inlist /\ + read (memory :> bytes (table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list) /\ + read RDI s = res /\ read RSI s = buf /\ read RDX s = table /\ + read YMM0 s = + word 115366376096492355175489748997433888275274855593258845241081954797768348401920 /\ + read YMM1 s = + word 226156397384342666605459106258636701594091082888230722833791023177481060351 /\ + read YMM2 s = + word 225935595421087293402315996791205668696012104344015382954355885915737415681 /\ + (let outlist_i = REJ_SAMPLE(SUB_LIST(0, 8 * N + i) (inlist:(24 word)list)) in + let outlen_i = LENGTH outlist_i in + read RAX s = word outlen_i /\ + read RCX s = word(24 * N + 3 * i) /\ + read(memory :> bytes(res, 4 * outlen_i)) s = num_of_wordlist outlist_i) /\ + (exists e_acc. read events s = APPEND e_acc e /\ + memaccess_inbounds e_acc + [buf,840; table,2048] + [res,1024])) + (\s. bytes_loaded s (word pc) (BUTLAST mldsa_rej_uniform_tmc) /\ + read RIP s = word (if i + 1 < K then pc + 181 else pc + 242) /\ + read RSP s = stackpointer /\ + read (memory :> bytes (buf,840)) s = num_of_wordlist inlist /\ + read (memory :> bytes (table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list) /\ + read RDI s = res /\ read RSI s = buf /\ read RDX s = table /\ + read YMM0 s = + word 115366376096492355175489748997433888275274855593258845241081954797768348401920 /\ + read YMM1 s = + word 226156397384342666605459106258636701594091082888230722833791023177481060351 /\ + read YMM2 s = + word 225935595421087293402315996791205668696012104344015382954355885915737415681 /\ + (let outlist_j = REJ_SAMPLE(SUB_LIST(0, 8 * N + (i+1)) (inlist:(24 word)list)) in + let outlen_j = LENGTH outlist_j in + read RAX s = word outlen_j /\ + read RCX s = word(24 * N + 3 * (i+1)) /\ + read(memory :> bytes(res, 4 * outlen_j)) s = num_of_wordlist outlist_j) /\ + (exists e_acc. read events s = APPEND e_acc e /\ + memaccess_inbounds e_acc + [buf,840; table,2048] + [res,1024])) + (MAYCHANGE [RIP; RAX; RCX; R8; R9; R10] ,, + MAYCHANGE [ZMM0; ZMM1; ZMM2; ZMM3; ZMM4; + ZMM5; ZMM6; ZMM7; ZMM8; ZMM9; ZMM10; ZMM11; + ZMM12; ZMM13; ZMM14; ZMM15] ,, + MAYCHANGE SOME_FLAGS ,, MAYCHANGE [events] ,, + MAYCHANGE [memory :> bytes(res,1024)])`, + REPEAT GEN_TAC THEN REWRITE_TAC[NONOVERLAPPING_CLAUSES] THEN + (* Split the precondition conjunction: strip all conjuncts EXCEPT the final + disjunction (which we keep as a single assumption for late use). *) + DISCH_THEN(CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC + (CONJUNCTS_THEN2 ASSUME_TAC ASSUME_TAC)))))))))))) THEN + FIRST_X_ASSUM(MP_TAC o C MATCH_MP (ASSUME `i:num < K`) o + check (is_forall o concl)) THEN STRIP_TAC THEN + ABBREV_TAC `curlist = REJ_SAMPLE(SUB_LIST(0, 8 * N + i) (inlist:(24 word)list))` THEN + ABBREV_TAC `curlen = LENGTH(curlist:int32 list)` THEN + SUBGOAL_THEN `curlen < 256` ASSUME_TAC THENL + [MAP_EVERY EXPAND_TAC ["curlen"; "curlist"] THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + CONV_TAC(RATOR_CONV(LAND_CONV(TOP_DEPTH_CONV let_CONV))) THEN + ASM_REWRITE_TAC[] THEN + ENSURES_INIT_TAC "s0" THEN STRIP_EXISTS_ASSUM_TAC THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [36;37] THEN + SUBGOAL_THEN `read RIP s37 = word(pc + 188):int64` + (fun th -> TRY(FIRST_X_ASSUM(K ALL_TAC o check (fun th' -> + let c = concl th' in is_eq c && can (find_term is_cond) c && + can (find_term ((=) `s37:x86state`)) c && + can (find_term ((=) `RIP`)) c))) THEN ASSUME_TAC th) THENL + [FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + can (find_term ((=) `RIP`)) (concl th) && is_eq(concl th) && + can (find_term is_cond) (concl th))) THEN + MATCH_MP_TAC(TAUT `~p ==> (if p then (a:int64) else b) = b`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[REAL_EQ_SUB_RADD; REAL_OF_NUM_ADD; REAL_OF_NUM_EQ] THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; ALL_TAC] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [38;39] THEN + SUBGOAL_THEN `read RIP s39 = word(pc + 196):int64` + (fun th -> TRY(FIRST_X_ASSUM(K ALL_TAC o check (fun th' -> + let c = concl th' in is_eq c && can (find_term is_cond) c && + can (find_term ((=) `s39:x86state`)) c && + can (find_term ((=) `RIP`)) c))) THEN ASSUME_TAC th) THENL + [FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + can (find_term ((=) `RIP`)) (concl th) && is_eq(concl th) && + can (find_term is_cond) (concl th))) THEN + MATCH_MP_TAC(TAUT `~p ==> (if p then (a:int64) else b) = b`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[REAL_EQ_SUB_RADD; REAL_OF_NUM_ADD; REAL_OF_NUM_EQ] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; ALL_TAC] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC (40--46) THEN + ABBREV_TAC `r8val:int64 = read R8 s46` THEN + ASM_CASES_TAC `val(r8val:int64) < 8380417` THENL + [(* ACCEPT branch: val(r8val) < 8380417, so JAE not taken; store + ADD + JMP *) + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [47] THEN + SUBGOAL_THEN `read RIP s47 = word(pc + 233):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + is_eq(concl th) && can (find_term ((=) `read RIP s47`)) (concl th) && + can (find_term is_cond) (concl th))) THEN + MATCH_MP_TAC(MESON[] `~p ==> (q = (if p then (a:int64) else b) ==> q = b)`) THEN + (fun (asl, g) -> + let t32 = `:(32)word` in + let rec find_wa t = + if is_comb t then + let f, a = dest_comb t in + if is_comb f && is_const (fst(dest_comb f)) && + fst(dest_const(fst(dest_comb f))) = "word_and" && + type_of t = t32 && is_comb a && is_const(fst(dest_comb a)) && + fst(dest_const(fst(dest_comb a))) = "word" && + (try dest_small_numeral(snd(dest_comb a)) = 8388607 with _ -> false) + then Some t + else match find_wa f with Some r -> Some r | None -> find_wa a + else None in + match find_wa g with + | Some t -> + ABBREV_TAC (mk_eq(mk_var("zval", t32), t)) (asl, g) + | None -> failwith "zval abbrev: no match") THEN + REWRITE_TAC[INT_VAL_WORD_SUB_CASES; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `val(zval:(32)word) < 4294967296` ASSUME_TAC THENL + [MP_TAC(ISPEC `zval:(32)word` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_32] THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN + `val(zval:(32)word) MOD 18446744073709551616 MOD 4294967296 = val zval` + SUBST1_TAC THENL + [ASM_SIMP_TAC[MOD_LT; + ARITH_RULE `x < 4294967296 ==> x < 18446744073709551616`]; ALL_TAC] THEN + SUBGOAL_THEN `r8val:int64 = word_zx(zval:(32)word)` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + let c = concl th in + is_eq c && fst(dest_eq c) = `r8val:int64`)) THEN + FIRST_X_ASSUM(SUBST1_TAC o SYM o check (fun th -> + let c = concl th in + is_eq c && snd(dest_eq c) = `zval:(32)word`)) THEN + DISCH_THEN ACCEPT_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(r8val:int64) = val(zval:(32)word)` ASSUME_TAC THENL + [UNDISCH_TAC `r8val:int64 = word_zx(zval:(32)word)` THEN + DISCH_THEN SUBST1_TAC THEN MATCH_MP_TAC VAL_WORD_ZX THEN + REWRITE_TAC[DIMINDEX_32; DIMINDEX_64] THEN ARITH_TAC; ALL_TAC] THEN + COND_CASES_TAC THENL + [UNDISCH_TAC `&8380417:int <= &(val(zval:(32)word))` THEN + UNDISCH_TAC `val(r8val:int64) = val(zval:(32)word)` THEN + UNDISCH_TAC `val(r8val:int64) < 8380417` THEN + REWRITE_TAC[GSYM INT_OF_NUM_LE; GSYM INT_OF_NUM_LT; + GSYM INT_OF_NUM_EQ] THEN INT_ARITH_TAC; + INT_ARITH_TAC]; ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read RIP s47 = (if p then (a:int64) else b)`] THEN + (* MEMSAFE: keep events for DISCHARGE_MEMSAFE *) ALL_TAC THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [48] THEN + (* Convert the MOV store's bytes32 equation at s48 into a bytes(_,4) + equation, which VSTEPS can propagate through s49 (ADD) and s50 (JMP). *) + SUBGOAL_THEN + `read(memory :> bytes(word_add res (word(4 * val(word curlen:int64))),4)) + s48 = val(r8val:int64)` + ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o MATCH_MP BYTES32_TO_BYTES o check (fun th -> + let c = concl th in is_eq c && + can (find_term ((=) `bytes32`)) c && + can (find_term ((=) `s48:x86state`)) c)) THEN + FIRST_X_ASSUM(MP_TAC o check (fun th -> + let c = concl th in is_eq c && + can (find_term ((=) `r8val:int64`)) c && + fst(dest_eq c) = `r8val:int64`)) THEN + DISCH_THEN(fun th -> + REWRITE_TAC[th; VAL_WORD_ZX_GEN; DIMINDEX_32; DIMINDEX_64]) THEN + CONV_TAC NUM_REDUCE_CONV THEN + W(fun (_, g) -> + let t32 = `:(32)word` in + let rec find_wa t = + if is_comb t then + let f, a = dest_comb t in + if is_comb f && is_const (fst(dest_comb f)) && + fst(dest_const(fst(dest_comb f))) = "word_and" && + type_of t = t32 && is_comb a && is_const(fst(dest_comb a)) && + fst(dest_const(fst(dest_comb a))) = "word" && + (try dest_small_numeral(snd(dest_comb a)) = 8388607 + with _ -> false) + then Some t + else match find_wa f with Some r -> Some r | None -> find_wa a + else None in + match find_wa g with + | Some t -> + MP_TAC(ISPEC t VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV + | None -> failwith "VAL_BOUND search") THEN + DISCH_TAC THEN + ASM_SIMP_TAC[MOD_LT; + ARITH_RULE `x < 4294967296 ==> x < 18446744073709551616`]; + ALL_TAC] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [49;50] THEN + SUBGOAL_THEN + `REJ_SAMPLE(SUB_LIST(0, 8*N + i + 1) (inlist:(24 word)list)) = + APPEND curlist (REJ_SAMPLE(SUB_LIST(8*N + i, 1) inlist))` + ASSUME_TAC THENL + [SUBGOAL_THEN `8 * N + i + 1 = (8 * N + i) + 1` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + MP_TAC(ISPECL [`inlist:(24 word)list`; `8 * N + i`; `1:num`; `0:num`] + SUB_LIST_SPLIT) THEN + REWRITE_TAC[ADD_CLAUSES] THEN DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[REJ_SAMPLE_APPEND] THEN + ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `8 * N + i < 280` ASSUME_TAC THENL + [UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; ALL_TAC] THEN + (* ACCEPT byte-bridge: apply ACCEPT_REJ_SAMPLE_SINGLETON with the precise + hypotheses, gathered via narrow FIRST_X_ASSUM picks, to avoid the slow + ASM_REWRITE_TAC across the 280+ assumption list. *) + SUBGOAL_THEN + `REJ_SAMPLE(SUB_LIST(8*N + i, 1) (inlist:(24 word)list)) = + [word(val(r8val:int64)):int32]` + ASSUME_TAC THENL + [(* Normalize `1 * val(word(24*N+3*i))` → `3*(8*N+i)` so the r8val shape matches. *) + SUBGOAL_THEN `1 * val(word (24 * N + 3 * i):int64) = 3 * (8 * N + i) /\ + 1 * val(word (24 * N + 3 * i):int64) + 2 = 3 * (8 * N + i) + 2` + STRIP_ASSUME_TAC THENL + [REWRITE_TAC[MULT_CLAUSES; VAL_WORD; DIMINDEX_64] THEN + SUBGOAL_THEN `(24 * N + 3 * i) MOD 2 EXP 64 = 24 * N + 3 * i` + SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `24 * N + 3 * i <= 837` THEN + ARITH_TAC; + ARITH_TAC]; + ALL_TAC] THEN + (* Fetch the 7 hypotheses for ACCEPT_REJ_SAMPLE_SINGLETON and feed them + directly, without ASM_REWRITE. *) + MP_TAC(SPECL [`inlist:(24 word)list`; `buf:int64`; `s39:x86state`; + `s40:x86state`; `r8val:int64`; `N:num`; `i:num`] + ACCEPT_REJ_SAMPLE_SINGLETON) THEN + ANTS_TAC THENL + [CONV_TAC NUM_REDUCE_CONV THEN + REPEAT CONJ_TAC THENL + [(* 1: LENGTH inlist = 280 *) FIRST_ASSUM ACCEPT_TAC; + (* 2: 8*N+i < 280 *) + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + (* 3: 3*(8*N+i)+3 <= 840 *) + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + (* 4: mem read s39 *) FIRST_ASSUM ACCEPT_TAC; + (* 5: mem read s40 *) FIRST_ASSUM ACCEPT_TAC; + (* 6: val r8val < 8380417 *) FIRST_ASSUM ACCEPT_TAC; + (* 7: r8val = word_zx(...(word 3*(8*N+i))...) — discharge via MP_TAC + on the r8val equation from asl (which uses `1*val(word(24*N+3*i))`) + and then ASM_REWRITE_TAC[] using the arith normalization. *) + FIRST_ASSUM(MP_TAC o check (fun th -> + let c = concl th in is_eq c && fst(dest_eq c) = `r8val:int64`)) THEN + ASM_REWRITE_TAC[]]; + DISCH_THEN(fun th -> REWRITE_TAC[CONJUNCT2 th])]; + ALL_TAC] THEN + SUBGOAL_THEN + `REJ_SAMPLE(SUB_LIST(0, 8*N + i + 1) (inlist:(24 word)list)) = + APPEND curlist [word(val(r8val:int64)):int32]` + ASSUME_TAC THENL + [ASM_REWRITE_TAC[]; ALL_TAC] THEN + ASM_CASES_TAC `i + 1 < K` THENL + [ ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ASM_REWRITE_TAC[LENGTH_APPEND; LENGTH] THEN CONV_TAC NUM_REDUCE_CONV THEN + REPEAT CONJ_TAC THENL + [(* RAX: word_zx(word_add(word_zx(word curlen))(word 1)) = word(curlen+1) *) + REWRITE_TAC[VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + (* RCX: word_zx(word_add(word_zx(word(24*N+3*i)))(word 3)) = word(24*N+3*(i+1)) *) + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + (* Memory: bytes(res, 4*(curlen+1)) = num_of_wordlist (APPEND curlist [...]) *) + SUBGOAL_THEN `4 * (curlen + 1) = 4 * curlen + 4` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + (* Fold the RHS's big expanded word back to r8val *) + FIRST_ASSUM(fun th -> let c = concl th in + if is_eq c && fst(dest_eq c) = `r8val:int64` + then GEN_REWRITE_TAC (RAND_CONV o DEPTH_CONV) [SYM th] + else failwith "r8val fold") THEN + MP_TAC(ISPECL [`memory:(x86state,int64->byte)component`; `res:int64`; + `s50:x86state`; `curlist:int32 list`; + `[word(val(r8val:int64)):int32]`; `4 * curlen`; `4`] + BYTES_EQ_NUM_OF_WORDLIST_APPEND) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[DIMINDEX_32] THEN ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN CONJ_TAC THENL + [ASM_REWRITE_TAC[]; + (* Single-element write: num_of_wordlist [word(val r8val):int32] = + val(word(val r8val)) = val r8val (since < 2^32), and the bytes(_,4) + equation is propagated from s48 through VSTEPS 49-50. *) + REWRITE_TAC[num_of_wordlist; MULT_CLAUSES; ADD_CLAUSES] THEN + SUBGOAL_THEN `val(word(val(r8val:int64)):int32) = val r8val` + SUBST1_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `val(r8val:int64) < 8380417` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `val(word curlen:int64) = curlen` + (fun th -> RULE_ASSUM_TAC(REWRITE_RULE[th])) THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_64] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[]]; + (* MAYCHANGE closure *) + DISCHARGE_MEMSAFE_ASM_TAC; + REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC]; + (* ACCEPT i+1=K branch: step through CMP EAX,256; JAE (pc+242) to reach + pc+242, using the strengthened lemma's WOP disjunct *) + SUBGOAL_THEN `i + 1 = K` ASSUME_TAC THENL + [UNDISCH_TAC `~(i + 1 < K)` THEN UNDISCH_TAC `i < K` THEN ARITH_TAC; + ALL_TAC] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [51;52] THEN + (* Split on WOP disjunct: count-exit vs offset-exit *) + FIRST_ASSUM(DISJ_CASES_TAC o check (fun th -> is_disj (concl th))) THENL + [(* count-exit: 256 <= LENGTH(REJ_SAMPLE ...), so curlen+1 = 256. + The ACCEPT branch has REJ_SAMPLE(0, 8*N+i+1) = APPEND curlist [r8val], + and with i+1=K we get length = curlen+1, so 256 <= curlen+1. + Combined with curlen < 256: curlen+1 = 256. *) + SUBGOAL_THEN `curlen + 1 = 256` ASSUME_TAC THENL + [UNDISCH_TAC `256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))` THEN + UNDISCH_TAC `i + 1 = K` THEN DISCH_THEN(fun th -> REWRITE_TAC[SYM th]) THEN + UNDISCH_TAC `REJ_SAMPLE (SUB_LIST (0,8 * N + i + 1) (inlist:(24 word)list)) = + APPEND curlist [word(val(r8val:int64)):int32]` THEN + DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[LENGTH_APPEND; LENGTH] THEN + UNDISCH_TAC `LENGTH (curlist:int32 list) = curlen` THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `read RIP s52 = word(pc + 242):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s52`)) c && + can (find_term is_cond) c)) THEN + MATCH_MP_TAC (TAUT `p ==> (if p then (a:int64) else b) = a`) THEN + SUBGOAL_THEN `val(word_add (word_zx (word curlen:int64):(32)word) (word 1:(32)word)) = curlen + 1` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_SIMP_TAC[MOD_LT; ARITH_RULE `curlen < 256 ==> curlen < 18446744073709551616`; + ARITH_RULE `curlen < 256 ==> curlen < 4294967296`; + ARITH_RULE `curlen < 256 ==> curlen + 1 < 4294967296`]; + ALL_TAC] THEN + ASM_REWRITE_TAC[INT_VAL_WORD_SUB_CASES; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC INT_REDUCE_CONV; + ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read RIP s52 = (if p then (a:int64) else b)`] THEN + (* MEMSAFE: keep events for DISCHARGE_MEMSAFE *) ALL_TAC THEN + ENSURES_FINAL_STATE_TAC THEN + REWRITE_TAC[LET_DEF; LET_END_DEF] THEN + REPEAT CONJ_TAC THEN + ASM_REWRITE_TAC[LENGTH_APPEND; LENGTH] THENL + [(* RAX *) + REWRITE_TAC[VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + (* RCX *) + FIRST_X_ASSUM (SUBST1_TAC o SYM o check (fun th -> concl th = `i + 1 = K`)) THEN + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + (* Memory: bytes(res, 4*(curlen+1)) = num_of_wordlist (APPEND curlist [...]) *) + SUBGOAL_THEN `curlen + SUC 0 = curlen + 1` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `4 * (curlen + 1) = 4 * curlen + 4` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + FIRST_ASSUM(fun th -> let c = concl th in + if is_eq c && fst(dest_eq c) = `r8val:int64` + then GEN_REWRITE_TAC (RAND_CONV o DEPTH_CONV) [SYM th] + else failwith "r8val fold") THEN + MP_TAC(ISPECL [`memory:(x86state,int64->byte)component`; `res:int64`; + `s52:x86state`; `curlist:int32 list`; + `[word(val(r8val:int64)):int32]`; `4 * curlen`; `4`] + BYTES_EQ_NUM_OF_WORDLIST_APPEND) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[DIMINDEX_32] THEN ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN CONJ_TAC THENL + [ASM_REWRITE_TAC[]; + REWRITE_TAC[num_of_wordlist; MULT_CLAUSES; ADD_CLAUSES] THEN + SUBGOAL_THEN `val(word(val(r8val:int64)):int32) = val r8val` + SUBST1_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `val(r8val:int64) < 8380417` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `val(word curlen:int64) = curlen` + (fun th -> RULE_ASSUM_TAC(REWRITE_RULE[th])) THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_64] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[]]; + (* MAYCHANGE closure *) + DISCHARGE_MEMSAFE_ASM_TAC; + REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC]; + (* offset-exit: 837 < 24*N+3*K, sub-split on 256 <= curlen+1 *) + ASM_CASES_TAC `256 <= curlen + 1` THENL + [(* Case A: curlen+1 = 256 (same output as count-exit). *) + SUBGOAL_THEN `curlen + 1 = 256` ASSUME_TAC THENL + [UNDISCH_TAC `256 <= curlen + 1` THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `read RIP s52 = word(pc + 242):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s52`)) c && + can (find_term is_cond) c)) THEN + MATCH_MP_TAC (TAUT `p ==> (if p then (a:int64) else b) = a`) THEN + SUBGOAL_THEN `val(word_add (word_zx (word curlen:int64):(32)word) (word 1:(32)word)) = curlen + 1` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_SIMP_TAC[MOD_LT; ARITH_RULE `curlen < 256 ==> curlen < 18446744073709551616`; + ARITH_RULE `curlen < 256 ==> curlen < 4294967296`; + ARITH_RULE `curlen < 256 ==> curlen + 1 < 4294967296`]; + ALL_TAC] THEN + ASM_REWRITE_TAC[INT_VAL_WORD_SUB_CASES; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC INT_REDUCE_CONV; + ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read RIP s52 = (if p then (a:int64) else b)`] THEN + (* MEMSAFE: keep events for DISCHARGE_MEMSAFE *) ALL_TAC THEN + ENSURES_FINAL_STATE_TAC THEN + REWRITE_TAC[LET_DEF; LET_END_DEF] THEN + REPEAT CONJ_TAC THEN + ASM_REWRITE_TAC[LENGTH_APPEND; LENGTH] THENL + [(* RAX *) + REWRITE_TAC[VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + (* RCX *) + FIRST_X_ASSUM (SUBST1_TAC o SYM o check (fun th -> concl th = `i + 1 = K`)) THEN + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + (* Memory *) + SUBGOAL_THEN `curlen + SUC 0 = curlen + 1` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `4 * (curlen + 1) = 4 * curlen + 4` SUBST1_TAC THENL + [ARITH_TAC; ALL_TAC] THEN + FIRST_ASSUM(fun th -> let c = concl th in + if is_eq c && fst(dest_eq c) = `r8val:int64` + then GEN_REWRITE_TAC (RAND_CONV o DEPTH_CONV) [SYM th] + else failwith "r8val fold") THEN + MP_TAC(ISPECL [`memory:(x86state,int64->byte)component`; `res:int64`; + `s52:x86state`; `curlist:int32 list`; + `[word(val(r8val:int64)):int32]`; `4 * curlen`; `4`] + BYTES_EQ_NUM_OF_WORDLIST_APPEND) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[DIMINDEX_32] THEN ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN CONJ_TAC THENL + [ASM_REWRITE_TAC[]; + REWRITE_TAC[num_of_wordlist; MULT_CLAUSES; ADD_CLAUSES] THEN + SUBGOAL_THEN `val(word(val(r8val:int64)):int32) = val r8val` + SUBST1_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `val(r8val:int64) < 8380417` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `val(word curlen:int64) = curlen` + (fun th -> RULE_ASSUM_TAC(REWRITE_RULE[th])) THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_64] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[]]; + (* MAYCHANGE *) + DISCHARGE_MEMSAFE_ASM_TAC; + REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC]; + (* Case B: curlen+1 < 256 *) + (* Case B: curlen+1 < 256. Step through CMP ECX,837; JA at s52, + then X86_STEPS [53;54] after `wa` abbreviation, then close. *) + SUBGOAL_THEN `read RIP s52 = word(pc + 188):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s52`)) c && + can (find_term is_cond) c)) THEN + MATCH_MP_TAC (TAUT `~p ==> (if p then (a:int64) else b) = b`) THEN + SUBGOAL_THEN `val(word_add (word_zx (word curlen:int64):(32)word) (word 1:(32)word)) = curlen + 1` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_SIMP_TAC[MOD_LT; ARITH_RULE `curlen < 256 ==> curlen < 18446744073709551616`; + ARITH_RULE `curlen < 256 ==> curlen < 4294967296`; + ARITH_RULE `curlen < 256 ==> curlen + 1 < 4294967296`]; + ALL_TAC] THEN + ASM_REWRITE_TAC[INT_VAL_WORD_SUB_CASES; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_SIMP_TAC[MOD_LT; + ARITH_RULE `curlen < 256 ==> curlen + 1 < 18446744073709551616`; + ARITH_RULE `curlen < 256 ==> curlen + 1 < 4294967296`; + ARITH_RULE `256 < 4294967296`] THEN + UNDISCH_TAC `~(256 <= curlen + 1)` THEN + REWRITE_TAC[GSYM INT_OF_NUM_LE] THEN INT_ARITH_TAC; + ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read RIP s52 = (if p then (a:int64) else b)`] THEN + (* MEMSAFE: keep events for DISCHARGE_MEMSAFE *) ALL_TAC THEN + (* Abbreviate word_and sub-expression as `wa` to preserve r8val def *) + (fun (asl,g) -> + let rec findit = function + | [] -> failwith "no r8val def" + | (_, th) :: rest -> + let c = concl th in + if is_eq c && (try fst(dest_var(fst(dest_eq c))) = "r8val" with _ -> false) then + let rhs = snd(dest_eq c) in + (try let _, inner = dest_comb rhs in + ABBREV_TAC (mk_eq(mk_var("wa", type_of inner), inner)) (asl,g) + with _ -> findit rest) + else findit rest + in findit asl) THEN + SUBGOAL_THEN `val(r8val:int64) = val(wa:(32)word)` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + let c = concl th in + is_eq c && (try fst(dest_var(fst(dest_eq c))) = "r8val" with _ -> false))) THEN + DISCH_THEN SUBST1_TAC THEN + MATCH_MP_TAC VAL_WORD_ZX THEN + REWRITE_TAC[DIMINDEX_32; DIMINDEX_64] THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `word(val(r8val:int64)):(32)word = wa` ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN CONV_TAC WORD_BLAST; + ALL_TAC] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [53;54] THEN + SUBGOAL_THEN `read RIP s54 = word(pc + 242):int64` ASSUME_TAC THENL + [MP_TAC(SPECL [`N:num`; `i:num`] VAL_RCX_ADD3_ZX) THEN + ANTS_TAC THENL [FIRST_ASSUM ACCEPT_TAC; ALL_TAC] THEN + DISCH_TAC THEN + FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s54`)) c && + can (find_term is_cond) c)) THEN + MATCH_MP_TAC (TAUT `p ==> (if p then (a:int64) else b) = a`) THEN + REWRITE_TAC[VAL_WORD_SUB_CASES; VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `837 <= 24 * N + 3 * i + 3` (fun th -> REWRITE_TAC[th]) THENL + [UNDISCH_TAC `837 < 24 * N + 3 * K` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `~((24 * N + 3 * i + 3) - 837 = 0)` + (fun th -> REWRITE_TAC[th]) THENL + [UNDISCH_TAC `837 < 24 * N + 3 * K` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + ALL_TAC] THEN + REWRITE_TAC[DE_MORGAN_THM; NOT_CLAUSES] THEN + MP_TAC(SPECL [`837:num`; `24 * N + 3 * i + 3`] INT_OF_NUM_SUB) THEN + ANTS_TAC THENL + [UNDISCH_TAC `837 < 24 * N + 3 * K` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + ALL_TAC] THEN + DISCH_THEN(fun th -> REWRITE_TAC[SYM th]) THEN INT_ARITH_TAC; + ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS + [`read RIP s54 = (if p then (a:int64) else b)`] THEN + (* MEMSAFE: keep events for DISCHARGE_MEMSAFE *) ALL_TAC THEN + (* Pre-establish augmented memory invariant via MEMORY_CONJUNCT_CLOSURE *) + SUBGOAL_THEN `val(word curlen:int64) = curlen` + (fun th -> RULE_ASSUM_TAC(REWRITE_RULE[th])) THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_64] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN + `read (memory :> bytes (res, 4 * (curlen + 1))) s54 = + num_of_wordlist (APPEND curlist [word(val(r8val:int64)):int32])` + ASSUME_TAC THENL + [SUBGOAL_THEN `val(word(val(r8val:int64)):int32) = val r8val` + ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `val(r8val:int64) < 8380417` THEN ARITH_TAC; + ALL_TAC] THEN + MATCH_MP_TAC MEMORY_CONJUNCT_CLOSURE THEN + REPEAT CONJ_TAC THENL + [FIRST_ASSUM ACCEPT_TAC; + FIRST_ASSUM ACCEPT_TAC; + FIRST_ASSUM ACCEPT_TAC; + ONCE_REWRITE_TAC[ASSUME `val(word(val(r8val:int64)):int32) = val r8val`] THEN + FIRST_ASSUM ACCEPT_TAC]; + ALL_TAC] THEN + UNDISCH_THEN `r8val:int64 = word_zx(wa:(32)word)` + (fun th -> RULE_ASSUM_TAC(REWRITE_RULE[th]) THEN ASSUME_TAC th) THEN + ENSURES_FINAL_STATE_TAC THEN + REWRITE_TAC[LET_DEF; LET_END_DEF] THEN + REPEAT CONJ_TAC THEN + ASM_REWRITE_TAC[LENGTH_APPEND; LENGTH; + ARITH_RULE `curlen + SUC 0 = curlen + 1`] THENL + [(* RAX *) + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + (* RCX *) + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ZX_GEN; VAL_WORD_ADD; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + (* MAYCHANGE *) + DISCHARGE_MEMSAFE_ASM_TAC; + REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC]]]]; + (* REJECT branch: val(r8val) >= 8380417, JAE taken to pc+181 *) + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [47] THEN + SUBGOAL_THEN `read RIP s47 = word(pc + 181):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + is_eq(concl th) && can (find_term ((=) `read RIP s47`)) (concl th) && + can (find_term is_cond) (concl th))) THEN + MATCH_MP_TAC(MESON[] `p ==> (q = (if p then (a:int64) else b) ==> q = a)`) THEN + SUBGOAL_THEN `8380417 <= val(r8val:int64)` ASSUME_TAC THENL + [UNDISCH_TAC `~(val(r8val:int64) < 8380417)` THEN ARITH_TAC; ALL_TAC] THEN + (fun (asl, g) -> + let t32 = `:(32)word` in + let rec find_wa t = + if is_comb t then + let f, a = dest_comb t in + if is_comb f && is_const (fst(dest_comb f)) && + fst(dest_const(fst(dest_comb f))) = "word_and" && + type_of t = t32 && is_comb a && is_const(fst(dest_comb a)) && + fst(dest_const(fst(dest_comb a))) = "word" && + (try dest_small_numeral(snd(dest_comb a)) = 8388607 with _ -> false) + then Some t + else match find_wa f with Some r -> Some r | None -> find_wa a + else None in + match find_wa g with + | Some t -> + ABBREV_TAC (mk_eq(mk_var("zval", t32), t)) (asl, g) + | None -> failwith "zval abbrev: no match") THEN + REWRITE_TAC[INT_VAL_WORD_SUB_CASES; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `val(zval:(32)word) < 4294967296` ASSUME_TAC THENL + [MP_TAC(ISPEC `zval:(32)word` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_32] THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN + `val(zval:(32)word) MOD 18446744073709551616 MOD 4294967296 = val zval` + SUBST1_TAC THENL + [ASM_SIMP_TAC[MOD_LT; + ARITH_RULE `x < 4294967296 ==> x < 18446744073709551616`]; ALL_TAC] THEN + SUBGOAL_THEN `r8val:int64 = word_zx(zval:(32)word)` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + let c = concl th in + is_eq c && fst(dest_eq c) = `r8val:int64`)) THEN + FIRST_X_ASSUM(SUBST1_TAC o SYM o check (fun th -> + let c = concl th in + is_eq c && snd(dest_eq c) = `zval:(32)word`)) THEN + DISCH_THEN ACCEPT_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(r8val:int64) = val(zval:(32)word)` ASSUME_TAC THENL + [UNDISCH_TAC `r8val:int64 = word_zx(zval:(32)word)` THEN + DISCH_THEN SUBST1_TAC THEN MATCH_MP_TAC VAL_WORD_ZX THEN + REWRITE_TAC[DIMINDEX_32; DIMINDEX_64] THEN ARITH_TAC; ALL_TAC] THEN + COND_CASES_TAC THENL + [REFL_TAC; + UNDISCH_TAC `~(&8380417:int <= &(val(zval:(32)word)))` THEN + UNDISCH_TAC `val(r8val:int64) = val(zval:(32)word)` THEN + UNDISCH_TAC `8380417 <= val(r8val:int64)` THEN + REWRITE_TAC[GSYM INT_OF_NUM_LE; GSYM INT_OF_NUM_EQ] THEN + INT_ARITH_TAC]; ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read RIP s47 = (if p then (a:int64) else b)`] THEN + (* MEMSAFE: keep events for DISCHARGE_MEMSAFE *) ALL_TAC THEN + SUBGOAL_THEN + `REJ_SAMPLE(SUB_LIST(0, 8*N + i + 1) (inlist:(24 word)list)) = + APPEND curlist (REJ_SAMPLE(SUB_LIST(8*N + i, 1) inlist))` + ASSUME_TAC THENL + [SUBGOAL_THEN `8 * N + i + 1 = (8 * N + i) + 1` SUBST1_TAC THENL [ARITH_TAC; ALL_TAC] THEN + MP_TAC(ISPECL [`inlist:(24 word)list`; `8 * N + i`; `1:num`; `0:num`] SUB_LIST_SPLIT) THEN + REWRITE_TAC[ADD_CLAUSES] THEN DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[REJ_SAMPLE_APPEND] THEN + ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `8 * N + i < 280` ASSUME_TAC THENL + [UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; ALL_TAC] THEN + (* Pivot lemma: val r8val equals the 23 low bits of the list element. + Use the extracted PIVOT_VAL_EQ top-level lemma for fast application. *) + SUBGOAL_THEN `1 * val(word (24 * N + 3 * i):int64) = 3 * (8 * N + i) /\ + 1 * val(word (24 * N + 3 * i):int64) + 2 = 3 * (8 * N + i) + 2` + STRIP_ASSUME_TAC THENL + [REWRITE_TAC[MULT_CLAUSES; VAL_WORD; DIMINDEX_64] THEN + SUBGOAL_THEN `(24 * N + 3 * i) MOD 2 EXP 64 = 24 * N + 3 * i` + SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `24 * N + 3 * i <= 837` THEN + ARITH_TAC; + ARITH_TAC]; + ALL_TAC] THEN + SUBGOAL_THEN + `val(r8val:int64) = val(EL (8*N+i) (inlist:(24 word)list)) MOD 2 EXP 23` + ASSUME_TAC THENL + [MP_TAC(SPECL [`inlist:(24 word)list`; `buf:int64`; `s39:x86state`; + `s40:x86state`; `r8val:int64`; `N:num`; `i:num`] + PIVOT_VAL_EQ) THEN + ASM_REWRITE_TAC[ARITH_RULE `3 * 280 = 840`] THEN + ANTS_TAC THENL + [UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + DISCH_THEN ACCEPT_TAC]; + ALL_TAC] THEN + SUBGOAL_THEN + `REJ_SAMPLE(SUB_LIST(8 * N + i, 1) (inlist:(24 word)list)) = []` + ASSUME_TAC THENL + [REWRITE_TAC[SUB_LIST_1] THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[REJ_SAMPLE; MAP; FILTER] THEN + REWRITE_TAC[VAL_MOD_23_EQ_AND] THEN + COND_CASES_TAC THENL + [SUBGOAL_THEN + `val (word_and (word_zx (EL (8 * N + i) (inlist:(24 word)list)):int32) + (word 8388607):int32) = + val(EL (8 * N + i) (inlist:(24 word)list)) MOD 2 EXP 23` + SUBST_ALL_TAC THENL + [REWRITE_TAC[GSYM VAL_MOD_23_EQ_AND; VAL_WORD; DIMINDEX_32] THEN + MATCH_MP_TAC MOD_LT THEN + MP_TAC(ISPEC `EL (8 * N + i) (inlist:(24 word)list)` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_24] THEN ARITH_TAC; + ALL_TAC] THEN + UNDISCH_TAC `~(val(r8val:int64) < 8380417)` THEN + ASM_REWRITE_TAC[] THEN ARITH_TAC; + REFL_TAC]; ALL_TAC] THEN + SUBGOAL_THEN + `REJ_SAMPLE(SUB_LIST(0, 8 * N + i + 1) (inlist:(24 word)list)) = curlist` + ASSUME_TAC THENL + [ASM_REWRITE_TAC[APPEND_NIL]; ALL_TAC] THEN + ASM_CASES_TAC `i + 1 < K` THENL + [ ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[APPEND_NIL] THEN ASM_REWRITE_TAC[] THEN + REPEAT CONJ_TAC THENL + [(* RCX: word_zx(word_add(word_zx(word(24*N+3*i)))(word 3)) = word(24*N+3*(i+1)) *) + ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN ARITH_TAC; + (* MAYCHANGE closure *) + DISCHARGE_MEMSAFE_ASM_TAC; + REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC]; + (* i + 1 = K branch of REJECT — fully closed via WOP offset-exit. + Mirrors Case B ACCEPT i+1=K: JA not taken on curlen<256, then + CMP RCX,837 / JA taken to pc+242 using VAL_RCX_ADD3_ZX. *) + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [48] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [49] THEN + SUBGOAL_THEN `read RIP s49 = word(pc + 188):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s49`)) c && + can (find_term is_cond) c)) THEN + MATCH_MP_TAC (TAUT `~p ==> (if p then (a:int64) else b) = b`) THEN + REWRITE_TAC[INT_VAL_WORD_SUB_CASES; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen MOD 18446744073709551616 MOD 4294967296 = curlen` + SUBST1_TAC THENL + [ASM_SIMP_TAC[MOD_LT; + ARITH_RULE `curlen < 256 ==> curlen < 18446744073709551616`; + ARITH_RULE `curlen < 256 ==> curlen < 4294967296`]; + ALL_TAC] THEN + SUBGOAL_THEN `~(&256:int <= &curlen)` ASSUME_TAC THENL + [UNDISCH_TAC `curlen < 256` THEN + REWRITE_TAC[GSYM INT_OF_NUM_LT; GSYM INT_OF_NUM_LE] THEN INT_ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + UNDISCH_TAC `curlen < 256` THEN + REWRITE_TAC[GSYM INT_OF_NUM_LT] THEN + INT_ARITH_TAC; ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS + [`read RIP s49 = (if p then (a:int64) else b)`] THEN + (* MEMSAFE: keep events for DISCHARGE_MEMSAFE *) ALL_TAC THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [50] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC [51] THEN + FIRST_ASSUM(DISJ_CASES_TAC o check (fun th -> is_disj (concl th))) THENL + [SUBGOAL_THEN `i + 1 = K` ASSUME_TAC THENL + [UNDISCH_TAC `~(i + 1 < K)` THEN UNDISCH_TAC `i < K` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `F` MP_TAC THENL + [UNDISCH_TAC + `256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))` THEN + UNDISCH_TAC `REJ_SAMPLE (SUB_LIST (0,8 * N + i + 1) inlist) = curlist` THEN + UNDISCH_TAC `i + 1 = K` THEN + DISCH_THEN(SUBST1_TAC o SYM) THEN + DISCH_THEN SUBST1_TAC THEN + UNDISCH_TAC `LENGTH (curlist:int32 list) = curlen` THEN + UNDISCH_TAC `curlen < 256` THEN ARITH_TAC; + MESON_TAC[]]; + SUBGOAL_THEN `i + 1 = K` ASSUME_TAC THENL + [UNDISCH_TAC `~(i + 1 < K)` THEN UNDISCH_TAC `i < K` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `read RIP s51 = word(pc + 242):int64` ASSUME_TAC THENL + [MP_TAC(SPECL [`N:num`; `i:num`] VAL_RCX_ADD3_ZX) THEN + ANTS_TAC THENL [FIRST_ASSUM ACCEPT_TAC; ALL_TAC] THEN + DISCH_TAC THEN + FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s51`)) c && + can (find_term is_cond) c)) THEN + MATCH_MP_TAC (TAUT `p ==> (if p then (a:int64) else b) = a`) THEN + REWRITE_TAC[VAL_WORD_SUB_CASES; VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `837 <= 24 * N + 3 * i + 3` (fun th -> REWRITE_TAC[th]) THENL + [UNDISCH_TAC `837 < 24 * N + 3 * K` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `~((24 * N + 3 * i + 3) - 837 = 0)` + (fun th -> REWRITE_TAC[th]) THENL + [UNDISCH_TAC `837 < 24 * N + 3 * K` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + ALL_TAC] THEN + REWRITE_TAC[DE_MORGAN_THM; NOT_CLAUSES] THEN + MP_TAC(SPECL [`837:num`; `24 * N + 3 * i + 3`] INT_OF_NUM_SUB) THEN + ANTS_TAC THENL + [UNDISCH_TAC `837 < 24 * N + 3 * K` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + ALL_TAC] THEN + DISCH_THEN(fun th -> REWRITE_TAC[SYM th]) THEN INT_ARITH_TAC; + ALL_TAC] THEN + DISCARD_MATCHING_ASSUMPTIONS + [`read RIP s51 = (if p then (a:int64) else b)`] THEN + (* MEMSAFE: keep events for DISCHARGE_MEMSAFE *) ALL_TAC THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[APPEND_NIL] THEN ASM_REWRITE_TAC[] THEN + REPEAT CONJ_TAC THENL + [ONCE_REWRITE_TAC[GSYM VAL_EQ] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_ZX_GEN; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[MOD_LT] THEN + UNDISCH_TAC `24 * N + 3 * i <= 837` THEN + UNDISCH_TAC `i + 1 = K` THEN ARITH_TAC; + DISCHARGE_MEMSAFE_ASM_TAC; + REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC]]]]);; + + +let MLDSA_REJ_UNIFORM_MEMSAFE = prove + (`!res buf table (inlist:(24 word)list) e pc. + LENGTH inlist = 280 /\ + nonoverlapping (word pc, 243) (res, 1024) /\ + nonoverlapping (word pc, 243) (buf, 840) /\ + nonoverlapping (word pc, 243) (table, 2048) /\ + nonoverlapping (res, 1024) (buf, 840) /\ + nonoverlapping (res, 1024) (table, 2048) + ==> ensures x86 + (\s. bytes_loaded s (word pc) (BUTLAST mldsa_rej_uniform_tmc) /\ + read RIP s = word pc /\ + C_ARGUMENTS [res; buf; table] s /\ + read(memory :> bytes(buf,840)) s = num_of_wordlist inlist /\ + read(memory :> bytes(table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list) /\ + read events s = e) + (\s. read RIP s = word(pc + 242) /\ + (exists e2. + read events s = APPEND e2 e /\ + memaccess_inbounds e2 + [buf,840; table,2048] + [res,1024])) + (MAYCHANGE [RIP; RAX; RCX; R8; R9; R10] ,, + MAYCHANGE [ZMM0; ZMM1; ZMM2; ZMM3; ZMM4; + ZMM5; ZMM6; ZMM7; ZMM8; ZMM9; ZMM10; ZMM11; + ZMM12; ZMM13; ZMM14; ZMM15] ,, + MAYCHANGE SOME_FLAGS ,, MAYCHANGE [events] ,, + MAYCHANGE [memory :> bytes(res,1024)])`, + MAP_EVERY X_GEN_TAC + [`res:int64`; `buf:int64`; `table:int64`; + `inlist:(24 word)list`; `e:(uarch_event)list`; `pc:num`] THEN + REWRITE_TAC[C_ARGUMENTS; C_RETURN; SOME_FLAGS; NONOVERLAPPING_CLAUSES] THEN + STRIP_TAC THEN + GHOST_INTRO_TAC `stackpointer:int64` `read RSP` THEN + + SUBGOAL_THEN + `?i. 832 < 24 * (i + 1) \/ 248 < LENGTH(REJ_SAMPLE(SUB_LIST(0,8 * i) inlist))` + MP_TAC THENL + [EXISTS_TAC `280:num` THEN ARITH_TAC; + GEN_REWRITE_TAC LAND_CONV [num_WOP]] THEN + DISCH_THEN(X_CHOOSE_THEN `N:num` (CONJUNCTS_THEN2 ASSUME_TAC MP_TAC)) THEN + DISCH_THEN(fun th -> ASSUME_TAC(REWRITE_RULE[DE_MORGAN_THM; NOT_LT] th)) THEN + SUBGOAL_THEN `~(N = 0)` ASSUME_TAC THENL + [DISCH_TAC THEN FIRST_X_ASSUM(MP_TAC o check (is_disj o concl)) THEN + ASM_REWRITE_TAC[MULT_CLAUSES; ADD_CLAUSES; SUB_LIST_CLAUSES] THEN + REWRITE_TAC[REJ_SAMPLE_EMPTY; LENGTH] THEN ARITH_TAC; + ALL_TAC] THEN + + ENSURES_WHILE_UP2_TAC `N:num` `pc + 104` `pc + 181` + `\i s. + read RSP s = stackpointer /\ + read (memory :> bytes (buf,840)) s = num_of_wordlist inlist /\ + read (memory :> bytes (table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list) /\ + read RDI s = res /\ read RSI s = buf /\ read RDX s = table /\ + read YMM0 s = + word 115366376096492355175489748997433888275274855593258845241081954797768348401920 /\ + read YMM1 s = + word 226156397384342666605459106258636701594091082888230722833791023177481060351 /\ + read YMM2 s = + word 225935595421087293402315996791205668696012104344015382954355885915737415681 /\ + (let outlist = REJ_SAMPLE(SUB_LIST(0,8*i) inlist) in + let outlen = LENGTH outlist in + read RAX s = word outlen /\ + read RCX s = word(24*i) /\ + read(memory :> bytes(res,4*outlen)) s = num_of_wordlist outlist) /\ + (exists e_acc. + read events s = APPEND e_acc e /\ + memaccess_inbounds e_acc + [buf,840; table,2048] [res,1024])` THEN + ASM_REWRITE_TAC[LT_REFL] THEN REPEAT CONJ_TAC THENL + + [(* Phase 1: Pre-loop *) + ENSURES_INIT_TAC "s0" THEN + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC (1--17) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[MULT_CLAUSES; ADD_CLAUSES; SUB_LIST_CLAUSES; REJ_SAMPLE_EMPTY; + LENGTH; num_of_wordlist] THEN + CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[READ_COMPONENT_COMPOSE; READ_MEMORY_BYTES_TRIVIAL] THEN + CONV_TAC WORD_REDUCE_CONV THEN + EXISTS_TAC `[]:(uarch_event)list` THEN + REWRITE_TAC[APPEND; memaccess_inbounds; ALL]; + + X_GEN_TAC `i:num` THEN STRIP_TAC THEN ASM_REWRITE_TAC[] THEN + + ABBREV_TAC `curlist = REJ_SAMPLE (SUB_LIST(0,8*i) inlist)` THEN + ABBREV_TAC `curlen = LENGTH(curlist:int32 list)` THEN + CONV_TAC(RATOR_CONV(LAND_CONV(TOP_DEPTH_CONV let_CONV))) THEN + ASM_REWRITE_TAC[] THEN + + (* (a) Get bounds from WOP at i *) + FIRST_ASSUM(MP_TAC o C MATCH_MP (ASSUME `i:num < N`) o + check (fun th -> is_forall(concl th))) THEN + ASM_REWRITE_TAC[] THEN STRIP_TAC THEN + + SUBGOAL_THEN `curlen <= 248` ASSUME_TAC THENL + [ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `24 * i <= 808` ASSUME_TAC THENL + [UNDISCH_TAC `24 * (i + 1) <= 832` THEN ARITH_TAC; ALL_TAC] THEN + + ENSURES_INIT_TAC "s0" THEN + STRIP_EXISTS_ASSUM_TAC THEN + + (* (b) Instructions 18-19: CMP eax,0xF8; JA — not taken. + For MEMSAFE, derive the COND simplification rewrite and apply it to + all assumptions (including events chain) before discarding the + COND-laden RIP hypothesis. Pin types via explicit annotations to + avoid type_invention pollution that breaks downstream MATCH_MP_TAC. *) + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [18;19] THEN + SUBGOAL_THEN `read RIP s19 = (word(pc + 111):int64)` ASSUME_TAC THENL + [RESOLVE_JA_CURLEN_TAC; ALL_TAC] THEN + (* Derive COND_s19 = word(pc+111) and rewrite events. *) + (fun (asl,w) -> + try + let cond_th = snd(List.find (fun (_,th) -> + let c = concl th in + is_eq c && + can (find_term ((=) `read RIP s19`)) c && + can (find_term (fun t -> + try fst(dest_const t) = "COND" with _ -> false)) (rhs c)) asl) in + let clean_th = snd(List.find (fun (_,th) -> + concl th = `read RIP s19 = (word(pc + 111):int64)`) asl) in + let cond_eq_clean = TRANS (SYM cond_th) clean_th in + RULE_ASSUM_TAC (REWRITE_RULE [cond_eq_clean]) (asl,w) + with _ -> ALL_TAC (asl,w)) THEN + TRY (FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s19`)) c && + can (find_term (fun t -> + try fst(dest_const t) = "COND" with _ -> false)) (rhs c)))) THEN + + (* (c) Instructions 20-21: CMP ecx,0x328; JA — not taken *) + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [20;21] THEN + SUBGOAL_THEN `read RIP s21 = (word(pc + 119):int64)` ASSUME_TAC THENL + [RESOLVE_JA_OFFSET_TAC; ALL_TAC] THEN + (fun (asl,w) -> + try + let cond_th = snd(List.find (fun (_,th) -> + let c = concl th in + is_eq c && + can (find_term ((=) `read RIP s21`)) c && + can (find_term (fun t -> + try fst(dest_const t) = "COND" with _ -> false)) (rhs c)) asl) in + let clean_th = snd(List.find (fun (_,th) -> + concl th = `read RIP s21 = (word(pc + 119):int64)`) asl) in + let cond_eq_clean = TRANS (SYM cond_th) clean_th in + RULE_ASSUM_TAC (REWRITE_RULE [cond_eq_clean]) (asl,w) + with _ -> ALL_TAC (asl,w)) THEN + TRY (FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s21`)) c && + can (find_term (fun t -> + try fst(dest_const t) = "COND" with _ -> false)) (rhs c)))) THEN + + (* (d) SIMD body: all verbose to preserve VMOVDQU→VPSHUFB→VPAND chain *) + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC (22--29) THEN + + (* Abbreviate the 8 masked coefficients from YMM3 after VPAND *) + (* Semantic bridge: use POPCNT_VMOVMSKPS_LEMMA to establish + R9 = word(LENGTH(FILTER)) for the 8 masked dword lanes. + The YMM3 at s26 = word_and(read YMM3 s25)(mask_broadcast). + After ASM_REWRITE, the read R9 s29 expands to the popcount + of the sign-bit mask, and the LEMMA matches directly. *) + SUBGOAL_THEN + `read R9 s29:int64 = + word(LENGTH(FILTER (\c:int32. val c < 8380417) + [word_subword (read YMM3 s26:int256) (0,32):int32; + word_subword (read YMM3 s26) (32,32); + word_subword (read YMM3 s26) (64,32); + word_subword (read YMM3 s26) (96,32); + word_subword (read YMM3 s26) (128,32); + word_subword (read YMM3 s26) (160,32); + word_subword (read YMM3 s26) (192,32); + word_subword (read YMM3 s26) (224,32)]))` + MP_TAC THENL + [ASM_REWRITE_TAC[] THEN + CONV_TAC(LAND_CONV(REWR_CONV word_zx)) THEN + REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + CONV_TAC(ONCE_DEPTH_CONV NUM_REDUCE_CONV) THEN + AP_TERM_TAC THEN + REPEAT(FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + can (find_term ((=) `s25:x86state`)) (concl th) || + can (find_term ((=) `s26:x86state`)) (concl th) || + can (find_term ((=) `s27:x86state`)) (concl th) || + can (find_term ((=) `s28:x86state`)) (concl th) || + can (find_term ((=) `s29:x86state`)) (concl th)))) THEN + SIMP_TAC[WORD_ZX_ZX; DIMINDEX_32; DIMINDEX_64; + ARITH_RULE `32 <= 64`] THEN + SIMP_TAC[WORD_POPCOUNT_WORD_ZX; DIMINDEX_8; DIMINDEX_32; + ARITH_RULE `8 <= 32`] THEN + REWRITE_TAC[VMOVMSKPS_POPCOUNT_EQ; BIT_NESTED_JOIN_8] THEN + REWRITE_TAC[POPCNT_VMOVMSKPS_LEMMA] THEN + MATCH_MP_TAC MOD_LT THEN + TRANS_TAC LTE_TRANS `9` THEN CONJ_TAC THENL + [MATCH_MP_TAC(ARITH_RULE `n <= 8 ==> n < 9`) THEN + W(MP_TAC o PART_MATCH lhand LENGTH_FILTER o lhand o snd) THEN + REWRITE_TAC[LENGTH] THEN ARITH_TAC; + ARITH_TAC]; + DISCARD_MATCHING_ASSUMPTIONS [`read R9 s = x`] THEN STRIP_TAC] THEN + + (* SIMD↔spec: prove BEFORE DISCARD while YMM3/buffer hyps available. + newlen (the FILTER length over SIMD coefficients) = LENGTH(REJ_SAMPLE(...)) + This is the key semantic bridge: VPERMQ+VPSHUFB+VPAND = spec's MAP+FILTER. + The result is state-independent and survives DISCARD_OLDSTATE_TAC. + Approach: WORD_SIMPLE_SUBWORD_CONV reduces the 256-bit VPSHUFB chain + to clean 8-bit byte extractions from the bytes256 memory read. Then + bytes256 → bytes(32) → MOD 2^192 → num_of_wordlist(SUB_LIST). *) + SUBGOAL_THEN + `FILTER (\c:int32. val c < 8380417) + [word_subword (read YMM3 s26:int256) (0,32):int32; + word_subword (read YMM3 s26) (32,32); + word_subword (read YMM3 s26) (64,32); + word_subword (read YMM3 s26) (96,32); + word_subword (read YMM3 s26) (128,32); + word_subword (read YMM3 s26) (160,32); + word_subword (read YMM3 s26) (192,32); + word_subword (read YMM3 s26) (224,32)] = + REJ_SAMPLE(SUB_LIST(8*i,8) inlist)` + ASSUME_TAC THENL + [REWRITE_TAC[REJ_SAMPLE] THEN + REPEAT(FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + can (find_term ((=) `newlen:num`)) (concl th) || + can (find_term ((=) `R9`)) (concl th)))) THEN + REPEAT(FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + let c = concl th in + not(can (find_term ((=) `YMM3`)) c && + can (find_term ((=) (mk_var("s26",`:x86state`)))) c) && + not(can (find_term ((=) `inlist:(24 word)list`)) c && + can (find_term (fun t -> + try fst(dest_const t) = "num_of_wordlist" with _ -> false)) c && + can (find_term ((=) (mk_var("s21",`:x86state`)))) c) && + (can (find_term (fun t -> try let s = fst(dest_var t) in + String.length s > 0 && s.[0] = 's' && s <> "stackpointer" + with _ -> false)) c || + can (find_term ((=) `MAYCHANGE`)) c || + can (find_term ((=) `bytes_loaded`)) c)))) THEN + FIRST_X_ASSUM(fun th -> + if can (find_term ((=) `YMM3`)) (concl th) && + can (find_term ((=) (mk_var("s26",`:x86state`)))) (concl th) && + is_eq(concl th) + then GEN_REWRITE_TAC (ONCE_DEPTH_CONV) [th] + else failwith "") THEN + CONV_TAC(TOP_DEPTH_CONV + (REWR_CONV WORD_SUBWORD_AND ORELSEC WORD_SIMPLE_SUBWORD_CONV)) THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + SUBGOAL_THEN `1 * val(word(24 * i):int64) = 24 * i` SUBST1_TAC THENL + [REWRITE_TAC[MULT_CLAUSES; VAL_WORD; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `24 * i <= 808` THEN ARITH_TAC; + ALL_TAC] THEN + REWRITE_TAC[BYTE_JOIN_ZX; BYTE_JOIN_SUBWORD_ZX; + bytes256; READ_COMPONENT_COMPOSE; asword; through; read] THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + SUBGOAL_THEN + `read(memory :> bytes(word_add buf (word(24*i)),24)) s21 = + num_of_wordlist(SUB_LIST(8*i,8) (inlist:(24 word)list))` + ASSUME_TAC THENL + [REWRITE_TAC[NUM_OF_WORDLIST_SUB_LIST; DIMINDEX_24] THEN + CONV_TAC NUM_REDUCE_CONV THEN + FIRST_ASSUM(fun th -> + if is_eq(concl th) && + can (find_term (fun t -> + try fst(dest_const t) = "num_of_wordlist" with _ -> false)) (concl th) && + not(can (find_term (fun t -> + try fst(dest_const t) = "SUB_LIST" with _ -> false)) (concl th)) && + (let s = string_of_term(concl th) in + let n = String.length s in + let rec has840 j = j + 2 < n && + (s.[j] = '8' && s.[j+1] = '4' && s.[j+2] = '0' || has840 (j+1)) in + has840 0) + then GEN_REWRITE_TAC (RAND_CONV o LAND_CONV o LAND_CONV) [GSYM th] + else failwith "") THEN + REWRITE_TAC[GSYM READ_BYTES_DIV; GSYM READ_BYTES_MOD; + ARITH_RULE `8 * (24 * i) = 192 * i`; + ARITH_RULE `8 * 24 = 192`] THEN + REWRITE_TAC[READ_COMPONENT_COMPOSE; READ_BYTES_DIV; READ_BYTES_MOD] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `MIN (840 - 24 * i) 24 = 24` SUBST1_TAC THENL + [REWRITE_TAC[MIN] THEN + UNDISCH_TAC `24 * i <= 808` THEN ARITH_TAC; + REWRITE_TAC[ARITH_RULE `24 * 8 * i = 8 * (24 * i)`] THEN + GEN_REWRITE_TAC (RAND_CONV o ONCE_DEPTH_CONV) + [GSYM(NUM_REDUCE_CONV `2 EXP (8 * 24)`)] THEN + REWRITE_TAC[READ_BYTES_DIV; READ_BYTES_MOD] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `MIN (840 - 24 * i) 24 = 24` SUBST1_TAC THENL + [REWRITE_TAC[MIN] THEN + UNDISCH_TAC `24 * i <= 808` THEN ARITH_TAC; + REFL_TAC]]; + ALL_TAC] THEN + SUBGOAL_THEN + `read(bytes(word_add buf (word(24*i)),32))(read memory s21) MOD + 2 EXP 192 = + num_of_wordlist(SUB_LIST(8*i,8) (inlist:(24 word)list))` + MP_TAC THENL + [FIRST_X_ASSUM(MP_TAC o REWRITE_RULE[READ_COMPONENT_COMPOSE]) THEN + DISCH_THEN(SUBST1_TAC o SYM) THEN + GEN_REWRITE_TAC (LAND_CONV o ONCE_DEPTH_CONV) + [GSYM(NUM_REDUCE_CONV `8 * 24`)] THEN + REWRITE_TAC[READ_BYTES_MOD] THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[MIN; ARITH_RULE `24 <= 32`]; + ALL_TAC] THEN + ABBREV_TAC + `n32 = read(bytes(word_add buf (word(24*i)),32))(read memory s21)` THEN + DISCH_TAC THEN + ASM_REWRITE_TAC[VAL_MOD_23_EQ_AND; COEFF_FROM_BYTES; + EL_NUM_OF_WORDLIST; NUM_OF_WORDLIST_SUB_LIST; + DIMINDEX_24] THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + ASM_REWRITE_TAC[] THEN + (* COEFF_BYTE_JOIN_WORD converts byte_join coefficients to word(n MOD 2^256 DIV 2^ofs MOD 2^23). + Use GEN_REWRITE with concrete instances for each of the 8 offsets. *) + (* Combined COEFF + MOD_256_192: byte_join coeffs → word(n32 MOD 2^192 DIV 2^k MOD 2^23) *) + GEN_REWRITE_TAC (LAND_CONV o DEPTH_CONV) + (map (fun k -> + let kterm = mk_small_numeral k in + let coeff_th = CONV_RULE NUM_REDUCE_CONV + (SPECL [`n32:num`; kterm] COEFF_BYTE_JOIN_WORD) in + let mod_th = CONV_RULE NUM_REDUCE_CONV + (SPECL [`n32:num`; kterm] MOD_256_192) in + CONV_RULE NUM_REDUCE_CONV (TRANS coeff_th (AP_TERM `word:num->int32` mod_th))) + [0;24;48;72;96;120;144;168]) THEN + CONV_TAC NUM_REDUCE_CONV THEN REWRITE_TAC[DIV_1] THEN + (* Convert huge 2^192 numeral back to 2 EXP 192 so hypothesis can match *) + REWRITE_TAC[GSYM(NUM_REDUCE_CONV `2 EXP 192`)] THEN + ASM_REWRITE_TAC[] THEN + (* Prove LENGTH(SUB_LIST(8*i,8) inlist) = 8 for REJ_SAMPLE_COEFFS *) + SUBGOAL_THEN `LENGTH(SUB_LIST(8*i,8) (inlist:(24 word)list)) = 8` + ASSUME_TAC THENL + [REWRITE_TAC[LENGTH_SUB_LIST] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC(ARITH_RULE + `8 * i + 8 <= l ==> MIN 8 (l - 8 * i) = 8`) THEN + ASM_ARITH_TAC; + ALL_TAC] THEN + ASM_SIMP_TAC[CONV_RULE NUM_REDUCE_CONV MAP_REJ_COEFFS]; + ALL_TAC] THEN + + (* Derive LENGTH from FILTER equality for the ABBREV *) + FIRST_X_ASSUM(fun th -> + if can (find_term (fun t -> try fst(dest_const t) = "FILTER" with _ -> false)) (concl th) && + can (find_term (fun t -> try fst(dest_const t) = "REJ_SAMPLE" with _ -> false)) (concl th) && + is_eq(concl th) && + not(can (find_term ((=) `LENGTH:(int32 list)->num`)) (concl th)) + then ASSUME_TAC th THEN ASSUME_TAC(AP_TERM `LENGTH:(int32 list)->num` th) + else failwith "not filter_eq") THEN + + (* Abbreviate the FILTER length to prevent DISCARD from seeing s26 ref *) + ABBREV_TAC `newlen = LENGTH(FILTER (\c:int32. val c < 8380417) + [word_subword (read YMM3 s26:int256) (0,32):int32; + word_subword (read YMM3 s26) (32,32); + word_subword (read YMM3 s26) (64,32); + word_subword (read YMM3 s26) (96,32); + word_subword (read YMM3 s26) (128,32); + word_subword (read YMM3 s26) (160,32); + word_subword (read YMM3 s26) (192,32); + word_subword (read YMM3 s26) (224,32)])` THEN + + (* The hypothesis `newlen = LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) inlist))` + already exists from the SIMD subgoal proof. It's state-free and + survives DISCARD. No need to re-derive it. *) + + (* Derive FILTER = REJ_SAMPLE BEFORE ABBREVs, while all state hyps exist. + The SIMD subgoal proved LENGTH equality. Now prove FILTER equality + by using read YMM3 s26 = read YMM3 s29 (unchanged through s27-s29). *) + SUBGOAL_THEN + `FILTER (\c:int32. val c < 8380417) + [word_subword (read YMM3 s29:int256) (0,32):int32; + word_subword (read YMM3 s29) (32,32); + word_subword (read YMM3 s29) (64,32); + word_subword (read YMM3 s29) (96,32); + word_subword (read YMM3 s29) (128,32); + word_subword (read YMM3 s29) (160,32); + word_subword (read YMM3 s29) (192,32); + word_subword (read YMM3 s29) (224,32)] = + REJ_SAMPLE(SUB_LIST(8*i,8) (inlist:(24 word)list))` + ASSUME_TAC THENL + [(* Use the FILTER=REJ_SAMPLE hypothesis (s26 version) to reduce to + FILTER P [s29 lanes] = FILTER P [s26 lanes], then ASM_REWRITE closes + because read YMM3 s29 = read YMM3 s26 (same VPAND EXPR). *) + FIRST_X_ASSUM(SUBST1_TAC o SYM o check (fun th -> + can (find_term (fun t -> try fst(dest_const t) = "FILTER" with _ -> false)) (concl th) && + can (find_term (fun t -> try fst(dest_const t) = "REJ_SAMPLE" with _ -> false)) (concl th) && + is_eq(concl th) && + not(can (find_term ((=) `LENGTH:(int32 list)->num`)) (concl th)))) THEN + ASM_REWRITE_TAC[]; + ALL_TAC] THEN + + (* Save the 8 bounds val(word_subword (read YMM3 s29) (k,32)) < 8388608 + BEFORE ABBREV substitutes coeffs_ymm3. Otherwise these bounds are + consumed as intermediate lemmas during CMP_MASK discharge and the + later VPERMD block's Step F picker (which looks for + `word_subword coeffs_ymm3 (k,32) ... < 8388608`) fails with Not_found. *) + SUBGOAL_THEN + `val(word_subword (read YMM3 s29:int256) (0,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (32,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (64,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (96,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (128,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (160,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (192,32):int32) < 8388608 /\ + val(word_subword (read YMM3 s29:int256) (224,32):int32) < 8388608` + STRIP_ASSUME_TAC THENL + [FIRST_ASSUM(fun th -> + let c = concl th in + if is_eq c && + (try fst(dest_const(fst(strip_comb(rhs c)))) = "word_and" with _ -> false) && + (try let ops,args = strip_comb(lhs c) in + fst(dest_const ops) = "read" && + List.length args = 2 && + fst(dest_const(List.hd args)) = "YMM3" + with _ -> false) + then SUBST1_TAC th + else failwith "no YMM3 word_and") THEN + REWRITE_TAC[WORD_SUBWORD_AND] THEN + CONV_TAC(DEPTH_CONV(WORD_SIMPLE_SUBWORD_CONV ORELSEC WORD_NUM_RED_CONV)) THEN + REPEAT CONJ_TAC THEN + MATCH_MP_TAC(ARITH_RULE `n <= 8388607 ==> n < 8388608`) THEN + REWRITE_TAC[VAL_WORD_AND_WORD_LE]; + ALL_TAC] THEN + + (* Ghost state: ABBREV key s29 values before DISCARD kills them. *) + ABBREV_TAC `coeffs_ymm3:int256 = read YMM3 s29` THEN + ABBREV_TAC `cmp_mask:int64 = read R8 s29` THEN + ABBREV_TAC `table_entry:int64 = + read (memory :> bytes64 (word_add table (word (8 * val (cmp_mask:int64))))) s29` THEN + + (* Preserve cmp_mask ↔ coefficient comparison relationship. + cmp_mask encodes which coefficients pass val < Q via VMOVMSKPS. + This connects cmp_mask to the FILTER predicate for the brute force. *) + SUBGOAL_THEN + `val(cmp_mask:int64) = + bitval(val(word_subword (coeffs_ymm3:int256) (0,32):int32) < 8380417) + + 2 * bitval(val(word_subword (coeffs_ymm3:int256) (32,32):int32) < 8380417) + + 4 * bitval(val(word_subword (coeffs_ymm3:int256) (64,32):int32) < 8380417) + + 8 * bitval(val(word_subword (coeffs_ymm3:int256) (96,32):int32) < 8380417) + + 16 * bitval(val(word_subword (coeffs_ymm3:int256) (128,32):int32) < 8380417) + + 32 * bitval(val(word_subword (coeffs_ymm3:int256) (160,32):int32) < 8380417) + + 64 * bitval(val(word_subword (coeffs_ymm3:int256) (192,32):int32) < 8380417) + + 128 * bitval(val(word_subword (coeffs_ymm3:int256) (224,32):int32) < 8380417)` + ASSUME_TAC THENL + [(* Use CMP_MASK_CORRECT: rewrite H31 (cmp_mask ABBREV) using GSYM H30 + (coeffs_ymm3 ABBREV) to replace the VPAND chain with coeffs_ymm3, + then apply the lemma directly. *) + FIRST_ASSUM(fun h30 -> + if is_eq(concl h30) && is_var(lhs(concl h30)) && + (try fst(dest_var(lhs(concl h30))) = "coeffs_ymm3" with _ -> false) && + (try fst(dest_const(fst(strip_comb(rhs(concl h30))))) = "word_and" + with _ -> false) + then + FIRST_ASSUM(fun h31 -> + if is_eq(concl h31) && is_var(lhs(concl h31)) && + (try fst(dest_var(lhs(concl h31))) = "cmp_mask" with _ -> false) && + (try fst(dest_const(fst(strip_comb(rhs(concl h31))))) = "word_zx" + with _ -> false) + then + SUBST1_TAC(REWRITE_RULE[GSYM h30] h31) + else failwith "not h31") + else failwith "not h30") THEN + (* Normalize the bit-31/word_subword/word-of-sum shape to match + CMP_MASK_CORRECT's word_of_bits form: first collapse the + bit 31 (word_subword x (k,32)) patterns via GSYM BIT_SUBWORD_256 + (so we see bit (32k+31) of the nested word_join), then fold the + word(sum of bitvals) via GSYM VMOVMSKPS_BYTE_EQ into word_of_bits. *) + REWRITE_TAC[GSYM BIT_SUBWORD_256; GSYM VMOVMSKPS_BYTE_EQ] THEN + MATCH_MP_TAC(ISPECL [ + `word_subword (coeffs_ymm3:int256) (0,32):int32`; + `word_subword (coeffs_ymm3:int256) (32,32):int32`; + `word_subword (coeffs_ymm3:int256) (64,32):int32`; + `word_subword (coeffs_ymm3:int256) (96,32):int32`; + `word_subword (coeffs_ymm3:int256) (128,32):int32`; + `word_subword (coeffs_ymm3:int256) (160,32):int32`; + `word_subword (coeffs_ymm3:int256) (192,32):int32`; + `word_subword (coeffs_ymm3:int256) (224,32):int32` + ] CMP_MASK_CORRECT) THEN + (* Prove val(word_subword coeffs_ymm3 (k,32)) < 8388608 for each k. + coeffs_ymm3 = word_and(big)(mask) so word_subword distributes, + mask subword = word 8388607, then VAL_WORD_AND_WORD_LE gives bound. *) + FIRST_ASSUM(fun h30 -> + if is_eq(concl h30) && is_var(lhs(concl h30)) && + (try fst(dest_var(lhs(concl h30))) = "coeffs_ymm3" with _ -> false) && + (try fst(dest_const(fst(strip_comb(rhs(concl h30))))) = "word_and" + with _ -> false) + then SUBST1_TAC h30 + else failwith "") THEN + REWRITE_TAC[WORD_SUBWORD_AND] THEN + CONV_TAC(DEPTH_CONV(WORD_SIMPLE_SUBWORD_CONV ORELSEC WORD_NUM_RED_CONV)) THEN + REPEAT CONJ_TAC THEN + MATCH_MP_TAC(ARITH_RULE `n <= 8388607 ==> n < 8388608`) THEN + REWRITE_TAC[VAL_WORD_AND_WORD_LE]; + ALL_TAC] THEN + + (* val(word curlen) = curlen — used by memaccess_inbounds for Store(res+4*val(word curlen),32). *) + SUBGOAL_THEN `val(word curlen:int64) = curlen` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_64] THEN + MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `curlen <= 248` THEN ARITH_TAC; + ALL_TAC] THEN + (* val(word(4*curlen)) = 4*curlen — for Store address word_add res (word(4*curlen)). *) + SUBGOAL_THEN `val(word(4 * curlen):int64) = 4 * curlen` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_64] THEN + MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `curlen <= 248` THEN ARITH_TAC; + ALL_TAC] THEN + (* val(word(24*i)) = 24*i — used by memaccess_inbounds for Load(buf+24*i,32). *) + SUBGOAL_THEN `val(word(24*i):int64) = 24*i` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_64] THEN + MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `24 * i <= 808` THEN ARITH_TAC; + ALL_TAC] THEN + (* val(word(1*24*i)) = 24*i — same, but with `1 *` prefix that surfaces + in the bytes256 read of buf+1*24*i. *) + SUBGOAL_THEN `val(word(1 * 24 * i):int64) = 24*i` ASSUME_TAC THENL + [REWRITE_TAC[MULT_CLAUSES; VAL_WORD; DIMINDEX_64] THEN + MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `24 * i <= 808` THEN ARITH_TAC; + ALL_TAC] THEN + (* Bound on val(cmp_mask) needed by memaccess_inbounds discharge for the + table EventLoad event (address = table + 8*val cmp_mask, size = 8). *) + SUBGOAL_THEN `val(cmp_mask:int64) <= 255` ASSUME_TAC THENL + [FIRST_ASSUM(fun th -> + if is_eq(concl th) && + (try fst(dest_var(lhs(concl th))) = "val" with _ -> false) || + (try let l = lhs(concl th) in + fst(dest_const(rator l)) = "val" && + fst(dest_var(rand l)) = "cmp_mask" + with _ -> false) + then SUBST1_TAC th + else failwith "no cmp_mask val eq") THEN + MAP_EVERY (fun k -> MP_TAC(SPEC k BITVAL_BOUND)) + [`val(word_subword (coeffs_ymm3:int256) (0,32):int32) < 8380417`; + `val(word_subword (coeffs_ymm3:int256) (32,32):int32) < 8380417`; + `val(word_subword (coeffs_ymm3:int256) (64,32):int32) < 8380417`; + `val(word_subword (coeffs_ymm3:int256) (96,32):int32) < 8380417`; + `val(word_subword (coeffs_ymm3:int256) (128,32):int32) < 8380417`; + `val(word_subword (coeffs_ymm3:int256) (160,32):int32) < 8380417`; + `val(word_subword (coeffs_ymm3:int256) (192,32):int32) < 8380417`; + `val(word_subword (coeffs_ymm3:int256) (224,32):int32) < 8380417`] THEN + ARITH_TAC; + ALL_TAC] THEN + (* val(word(8 * val cmp_mask)) = 8 * val cmp_mask — needs val cmp_mask <= 255 + so 8*val cmp_mask <= 2040 < 2^64. *) + SUBGOAL_THEN `val(word(8 * val(cmp_mask:int64)):int64) = 8 * val cmp_mask` + ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_64] THEN + MATCH_MP_TAC MOD_LT THEN + UNDISCH_TAC `val(cmp_mask:int64) <= 255` THEN ARITH_TAC; + ALL_TAC] THEN + + (* Use KEEP_EVENTS variant so the events chain (whose POPCNT operand + transitively references earlier states) isn't erased. *) + DISCARD_OLDSTATE_KEEP_EVENTS_TAC "s29" THEN CLARIFY_TAC THEN + (* Step 30-32 WITHOUT discard to keep VPERMD hypothesis chain *) + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC (30--32) THEN + SUBGOAL_THEN + `val(read YMM3 s32:int256) MOD 2 EXP (32 * newlen) = + num_of_wordlist(REJ_SAMPLE(SUB_LIST(8*i,8) (inlist:(24 word)list)))` + ASSUME_TAC THENL + [(* VPERMD: apply MLDSA_VPERMD_BRIDGE (proven in defs_extra.ml) + — replaces the former 256-case brute force, eliminating 255 cheats. *) + (* Step A: derive val(table_entry) = (table DIV 2^(64*val cmp_mask)) MOD 2^64 *) + SUBGOAL_THEN + `val(table_entry:int64) = + (num_of_wordlist mldsa_rej_uniform_table DIV + 2 EXP (64 * val(cmp_mask:int64))) MOD 2 EXP 64` + ASSUME_TAC THENL + [SUBGOAL_THEN + `val(read(memory :> bytes64(word_add table (word(8 * val(cmp_mask:int64))))) s32 :int64) = + (num_of_wordlist mldsa_rej_uniform_table DIV 2 EXP (64 * val cmp_mask)) MOD 2 EXP 64` + MP_TAC THENL + [MATCH_MP_TAC TABLE_ENTRY_FROM_MEMORY THEN CONJ_TAC THENL + [ASM_REWRITE_TAC[]; + FIRST_X_ASSUM(fun th -> + if can (find_term ((=) `bitval`)) (concl th) && is_eq(concl th) && + (try fst(dest_var(rand(lhs(concl th)))) = "cmp_mask" with _ -> false) + then SUBST1_TAC th else failwith "") THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (0,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (32,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (64,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (96,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (128,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (160,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (192,32):int32) < 8380417` BITVAL_BOUND) THEN + MP_TAC(SPEC `val(word_subword (coeffs_ymm3:int256) (224,32):int32) < 8380417` BITVAL_BOUND) THEN + ARITH_TAC]; + ASM_REWRITE_TAC[]]; ALL_TAC] THEN + (* Step B: substitute read YMM3 s32 into goal (exposes the VPERMD expansion). *) + FIRST_X_ASSUM (fun th -> + let s = string_of_term (concl th) in + let n = String.length s in + let contains needle = + let k = String.length needle in + let rec scan i = i + k <= n && (String.sub s i k = needle || scan (i+1)) in + scan 0 in + if contains "read YMM3 s32" && is_eq(concl th) && not(contains "MAYCHANGE") + then GEN_REWRITE_TAC (LAND_CONV o ONCE_DEPTH_CONV) [th] THEN ASSUME_TAC th + else failwith "not ymm3 s32") THEN + (* Step C: rewrite (32 * newlen) → (32 * bitval_sum) via newlen = LENGTH(FILTER) + and FILTER=REJ_SAMPLE; also convert RHS REJ_SAMPLE → FILTER. *) + (fun (asl, w) -> + let contains_s needle s = + let n = String.length needle in let m = String.length s in + let rec scan i = i + n <= m && (String.sub s i n = needle || scan (i+1)) in + scan 0 in + let b k = + let needle = Printf.sprintf "word_subword coeffs_ymm3 (%d,32)" k in + snd(List.find (fun (_,th) -> + let s = string_of_term (concl th) in + contains_s needle s && contains_s "< 8388608" s && not(contains_s "==>" s)) asl) in + let bounds = CONJ (b 0) (CONJ (b 32) (CONJ (b 64) (CONJ (b 96) (CONJ (b 128) (CONJ (b 160) (CONJ (b 192) (b 224))))))) in + let flt_spec = ISPECL [ + `word_subword (coeffs_ymm3:int256) (0,32):int32`; + `word_subword (coeffs_ymm3:int256) (32,32):int32`; + `word_subword (coeffs_ymm3:int256) (64,32):int32`; + `word_subword (coeffs_ymm3:int256) (96,32):int32`; + `word_subword (coeffs_ymm3:int256) (128,32):int32`; + `word_subword (coeffs_ymm3:int256) (160,32):int32`; + `word_subword (coeffs_ymm3:int256) (192,32):int32`; + `word_subword (coeffs_ymm3:int256) (224,32):int32` + ] FILTER_LENGTH_8 in + let length_filter_eq = MP flt_spec bounds in + let newlen_def = snd(List.find (fun (_, th) -> + is_eq(concl th) && is_var(lhs(concl th)) && + (try fst(dest_var(lhs(concl th))) = "newlen" with _ -> false)) asl) in + let filter_rej_eq = snd(List.find (fun (_, th) -> + let s = string_of_term (concl th) in + contains_s "FILTER" s && contains_s "REJ_SAMPLE" s && is_eq(concl th) && + not(contains_s "LENGTH" s) && not(contains_s "==>" s)) asl) in + let newlen_bitval = TRANS (TRANS newlen_def + (AP_TERM `LENGTH:int32 list -> num` (SYM filter_rej_eq))) length_filter_eq in + REWRITE_TAC[newlen_bitval; SYM filter_rej_eq] (asl, w)) THEN + (* Step D: fold raw memory read back to table_entry, then collapse word_zx(word_zx ...) via + WORD_SIMPLE_SUBWORD_CONV to expose word_subword table_entry (k,3). *) + (fun (asl, w) -> + let contains_s needle s = + let n = String.length needle in let m = String.length s in + let rec scan i = i + n <= m && (String.sub s i n = needle || scan (i+1)) in + scan 0 in + let cm_sym = + let th = snd(List.find (fun (_, th) -> + is_eq(concl th) && + (try fst(dest_var(rand(lhs(concl th)))) = "cmp_mask" with _ -> false) && + contains_s "bitval" (string_of_term (concl th))) asl) in + SYM th in + let te_eqs = List.filter_map (fun (_, th) -> + let s = string_of_term (concl th) in + if is_eq(concl th) && contains_s "table_entry" s && contains_s "bytes64" s + then Some th else None) asl in + (REWRITE_TAC[cm_sym] THEN REWRITE_TAC te_eqs THEN + CONV_TAC(TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV)) (asl, w)) THEN + (* Step E: apply MLDSA_VPERMD_BRIDGE specialized to coeffs_ymm3 and table_entry. *) + MATCH_MP_TAC (ISPECL [`coeffs_ymm3:int256`; `table_entry:int64`] MLDSA_VPERMD_BRIDGE) THEN + (* Step F: discharge the antecedent using targeted rewrites (bounds + te_val + cm_sym). + Full ASM_REWRITE_TAC hangs on the large assumption list, but this focused + set closes the 9 antecedent conjuncts in ~2s. *) + W(fun (asl,_) -> + let contains_s needle s = + let n = String.length needle in let m = String.length s in + let rec scan i = i + n <= m && (String.sub s i n = needle || scan (i+1)) in + scan 0 in + let b k = + let needle = Printf.sprintf "word_subword coeffs_ymm3 (%d,32)" k in + snd(List.find (fun (_,th) -> + let s = string_of_term (concl th) in + contains_s needle s && contains_s "< 8388608" s && not(contains_s "==>" s)) asl) in + let cm_sym = + let th = snd(List.find (fun (_, th) -> + is_eq(concl th) && + (try fst(dest_var(rand(lhs(concl th)))) = "cmp_mask" with _ -> false) && + contains_s "bitval" (string_of_term (concl th))) asl) in + SYM th in + let te_val = snd(List.find (fun (_, th) -> + is_eq(concl th) && + (try let l = lhs(concl th) in + fst(dest_comb l) = `val:int64->num` && + fst(dest_var(rand l)) = "table_entry" + with _ -> false) && + contains_s "DIV" (string_of_term (concl th))) asl) in + REWRITE_TAC [b 0; b 32; b 64; b 96; b 128; b 160; b 192; b 224; te_val; cm_sym]); + ALL_TAC] THEN + (* VSTEPS for all 3 steps to preserve bytes256 + VPERMD hyps *) + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC (33--35) THEN + + (* (e) COND_CASES_TAC: continue (i+1 < N) vs exit (~(i+1 < N)) *) + COND_CASES_TAC THENL + [(* i+1 < N: continue looping *) + (* Derive new region memory content BEFORE ENSURES consumes state. + From the bytes256 write hypothesis (VMOVDQU step), derive: + read(memory :> bytes(addr, 32)) sN = val(bytes256 RHS) + with address normalized (val(word curlen) → curlen). + This is used by VPERMD_MEMORY_BRIDGE in the memory store goal. *) + (fun (asl,w) -> + try + (* Find bytes256 hyp with s35 and res in address *) + let b256_th = snd(find (fun (_,th) -> + is_eq(concl th) && + can (find_term (fun t -> try fst(dest_const t) = "bytes256" with _ -> false)) (lhs(concl th)) && + can (find_term (fun t -> try fst(dest_var t) = "res" with _ -> false)) (lhs(concl th)) && + can (find_term (fun t -> try fst(dest_var t) = "s35" with _ -> false)) (lhs(concl th)) && + not(can (find_term (fun t -> try fst(dest_const t) = "MAYCHANGE" with _ -> false)) (concl th)) && + not(can (find_term (fun t -> try fst(dest_const t) = "events" with _ -> false)) (lhs(concl th)))) asl) in + (* Find read YMM3 s35 = to get clean RHS *) + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + let ymm3_s35 = snd(find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_var "s35")) (lhs(concl th)) && + can (find_term (has_const "YMM3")) (lhs(concl th)) && + not(can (find_term (has_const "MOD")) (concl th)) && + not(can (find_term (has_const "bytes256")) (concl th))) asl) in + (* Chain: bytes256 s35 = = YMM3 s35 by transitivity *) + let b256_ymm3 = TRANS b256_th (SYM ymm3_s35) in + (* Normalize address: val(word curlen) → curlen *) + let curlen_bound = snd(find (fun (_,th) -> + try concl th = `curlen <= 248` with _ -> false) asl) in + let mk_norm dim_thm dim_val = + let vwe = CONV_RULE NUM_REDUCE_CONV (REWRITE_RULE[dim_thm] (INST_TYPE [dim_val,`:N`] VAL_WORD_EQ)) in + let lt = CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 4294967296`) curlen_bound) in + try MATCH_MP vwe lt with _ -> + let lt64 = CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 18446744073709551616`) curlen_bound) in + MATCH_MP vwe lt64 in + let curlen_norm_32 = mk_norm DIMINDEX_32 `:32` in + let curlen_norm_64 = mk_norm DIMINDEX_64 `:64` in + let b256_norm = REWRITE_RULE[curlen_norm_32; curlen_norm_64] b256_ymm3 in + (* Convert: val(bytes256 addr s35) = val(read YMM3 s35) + address normalize *) + let val_eq = AP_TERM `val:int256->num` b256_norm in + let bytes32_eq = CONV_RULE(LAND_CONV( + REWRITE_CONV[READ_COMPONENT_COMPOSE; VAL_READ_BYTES256] THENC + REWRITE_CONV[GSYM READ_COMPONENT_COMPOSE])) val_eq in + (* Result: read(memory :> bytes(addr,32)) s35 = val(read YMM3 s35) *) + ASSUME_TAC bytes32_eq (asl,w) + with e -> + Printf.printf "pre-ENSURES bytes32 setup failed: %s\n%!" (Printexc.to_string e); + ALL_TAC (asl,w)) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + (* Establish iteration bounds *) + SUBGOAL_THEN `8 * (i + 1) <= LENGTH(inlist:(24 word)list)` ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN + UNDISCH_TAC `24 * (i + 1) <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + (* Use the SIMD↔spec result: newlen = LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8))) + ABBREV_TAC replaced FILTER... with newlen in this hypothesis *) + FIRST_X_ASSUM(SUBST1_TAC o check (fun th -> + is_eq(concl th) && + can (find_term ((=) `newlen:num`)) (concl th) && + can (find_term (fun t -> + try fst(dest_const t) = "REJ_SAMPLE" with _ -> false)) (concl th))) THEN + (* Apply SIMD_ITERATION_BRIDGE to split REJ_SAMPLE across iterations *) + MP_TAC(ISPECL [`inlist:(24 word)list`; `i:num`; `curlist:int32 list`; `curlen:num`] + SIMD_ITERATION_BRIDGE) THEN + ASM_REWRITE_TAC[] THEN STRIP_TAC THEN + ASM_REWRITE_TAC[LENGTH_APPEND; NUM_OF_WORDLIST_APPEND] THEN + (* Clean state hypotheses before word arithmetic — keep MAYCHANGE and memory *) + DISCARD_ASSUMPTIONS_TAC (fun th -> + let c = concl th in + (can (term_match [] `read (x:(x86state,num)component) (s:x86state) = (y:num)`) c && + not (can (find_term (fun t -> try fst(dest_const t) = "memory" with _ -> false)) c)) || + can (term_match [] `bytes_loaded (x:x86state) (y:int64) (z:byte list)`) c || + can (term_match [] `read (x:(x86state,bool)component) (s:x86state) <=> (y:bool)`) c) THEN + ABBREV_TAC `lenrej = LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) inlist))` THEN + REPEAT CONJ_TAC THENL + [(* RAX: word(curlen + lenrej) — word arithmetic. + Use targeted UNDISCH (not ASM_ARITH_TAC — hangs on ~200 asl). *) + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lenrej <= 8` THEN + SPEC_TAC(`lenrej:num`, `l:num`) THEN GEN_TAC THEN + SPEC_TAC(`curlen:num`, `c:num`) THEN GEN_TAC THEN + REPEAT DISCH_TAC THEN + SUBGOAL_THEN `c < 4294967296 /\ c < 18446744073709551616 /\ + l < 4294967296 /\ l < 18446744073709551616 /\ + c + l < 4294967296 /\ c + l < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `c <= 248` THEN UNDISCH_TAC `l <= 8` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC; + (* RCX: word(24*(i+1)) — word arithmetic *) + REWRITE_TAC[ARITH_RULE `24 * (i + 1) = 24 * i + 24`] THEN + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; VAL_WORD; + DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + UNDISCH_TAC `24 * i <= 808` THEN + SPEC_TAC(`24 * i`, `n:num`) THEN GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `n < 4294967296 /\ n < 18446744073709551616 /\ + n + 24 < 4294967296 /\ n + 24 < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `n <= 808` THEN ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC; + (* Memory store: use VPERMD_MEMORY_BRIDGE + We have (from PRE-ENSURES): + read(memory :> bytes(addr, 32)) s35 = val(read YMM3 sN) + And (from VPERMD): + val(read YMM3 sN) MOD 2^(32*lenrej) = num_of_wordlist(REJ_SAMPLE(...)) + VPERMD_MEMORY_BRIDGE chains these to close the sub-read goal. *) + REWRITE_TAC[DIMINDEX_32; + ARITH_RULE `4 * (curlen + lenrej) = 4 * curlen + 4 * lenrej`; + ARITH_RULE `32 * curlen = 8 * (4 * curlen)`] THEN + REWRITE_TAC[MEMORY_BYTES_SPLIT] THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[EQ_ADD_LCANCEL; EQ_MULT_LCANCEL; EXP_EQ_0; ARITH_EQ] THEN + (* Goal: read(bytes(addr, 4*lenrej)) s35 = num_of_wordlist(REJ_SAMPLE(...)) + Use VPERMD_MEMORY_BRIDGE with the PRE-ENSURES bytes32 hypothesis. + First find the hypothesis, then build the full closing tactic. *) + (fun (asl,w) -> + (* Find bytes32 hyp, VPERMD MOD hyp, lenrej bound, then forward-chain *) + try + (* 1. bytes32 hypothesis: read(bytes(addr,32)) s35 = vr *) + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + let bytes32_hyp = try snd(List.find (fun (_,th) -> + is_eq(concl th) && + can (find_term (fun t -> try dest_numeral t = Num.num_of_int 32 with _ -> false)) (lhs(concl th)) && + can (find_term (fun t -> try fst(dest_var t) = "s35" with _ -> false)) (lhs(concl th)) && + can (find_term (fun t -> try fst(dest_var t) = "res" with _ -> false)) (lhs(concl th)) && + can (find_term (fun t -> try fst(dest_const t) = "bytes" with _ -> false)) (lhs(concl th)) && + not(can (find_term (fun t -> try fst(dest_const t) = "bytes256" with _ -> false)) (lhs(concl th))) && + not(can (find_term (fun t -> try fst(dest_const t) = "events" with _ -> false)) (lhs(concl th)))) asl) with Not_found -> (Printf.printf "bytes32_hyp Not_found\n%!"; raise Not_found) in + (* Find newlen = lenrej hypothesis *) + let newlen_eq = try snd(List.find (fun (_,th) -> + try is_eq(concl th) && has_var "newlen" (lhs(concl th)) && + has_var "lenrej" (rhs(concl th)) + with _ -> false) asl) with Not_found -> (Printf.printf "newlen_eq Not_found\n%!"; raise Not_found) in + (* Find VPERMD MOD hyp: val(YMM3 sN) MOD 2^(32*newlen) = num_of_wordlist(...) + May be for s34 or s33 — find the most recent one *) + let vpermd_hyp_raw = try snd(List.find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "MOD")) (concl th) && + can (find_term (has_var "newlen")) (concl th) && + can (find_term (has_const "num_of_wordlist")) (concl th)) asl) with Not_found -> (Printf.printf "vpermd_hyp_raw Not_found\n%!"; raise Not_found) in + (* Normalize: replace newlen with lenrej *) + let vpermd_hyp_1 = REWRITE_RULE[newlen_eq] vpermd_hyp_raw in + (* The VPERMD hyp may use a different state (s34) than bytes32 (s35). + Bridge: find read YMM3 s35 = E and read YMM3 s34 = E, chain them. *) + let vpermd_hyp = try + (* Find read YMM3 sN = — require int256 RHS and YMM3 on LHS *) + let is_ymm3_read th = + is_eq(concl th) && + type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) in + let ymm35 = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var "s35")) (lhs(concl th))) asl) in + let ymm34 = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var "s34")) (lhs(concl th))) asl) in + (* read YMM3 s35 = E, read YMM3 s34 = E => read YMM3 s35 = read YMM3 s34 *) + let ymm_eq = TRANS ymm35 (SYM ymm34) in + let val_eq = AP_TERM `val:int256->num` ymm_eq in + (* Rewrite s34 → s35 in the VPERMD hypothesis *) + REWRITE_RULE[GSYM val_eq] vpermd_hyp_1 + with _ -> + vpermd_hyp_1 in + (* 3. lenrej <= 8: directly available *) + let lenrej_bound = try snd(List.find (fun (_,th) -> + try is_binary "<=" (concl th) && + has_var "lenrej" (lhand(concl th)) && + dest_small_numeral(rand(concl th)) = 8 + with _ -> false) asl) with Not_found -> (Printf.printf "lenrej_bound Not_found\n%!"; raise Not_found) in + (* Forward chain: MATCH_MP VPERMD_MEMORY_BRIDGE (bytes32 /\ mod /\ bound) *) + let bridge = MATCH_MP VPERMD_MEMORY_BRIDGE + (CONJ bytes32_hyp (CONJ vpermd_hyp lenrej_bound)) in + REWRITE_TAC[bridge] (asl,w) + with e -> + Printf.printf "memstore bridge: %s\n%!" (Printexc.to_string e); + failwith "memstore bridge derivation failed"); + W(fun (asl,w) -> + try + let _, body = dest_exists w in + if can (find_term (fun t -> + try fst(dest_const t) = "memaccess_inbounds" with _ -> false)) body + then + (DISCHARGE_MEMSAFE_ASM_TAC ORELSE + (try + let _, ainbds_th = List.find (fun (_,th) -> + let c = concl th in + try fst(dest_const(rator(rator(rator c)))) = "memaccess_inbounds" + with _ -> false) asl in + let e_var = rand(rator(rator(concl ainbds_th))) in + EXISTS_TAC e_var THEN ASM_REWRITE_TAC[] + with _ -> NO_TAC)) + else NO_TAC + with _ -> NO_TAC)]; + + (* ~(i+1 < N): exit to pc+181. + Approach: instead of DISJ_CASES on the outer disjunct, first derive + N = i+1, then ASM_CASES on `248 < curlen + lenrej`: + * if true: count-exit fires (JAE at s37 to pc+181) — same as old J2 + * if false: offset-exit path — VSTEPS 38-39 do CMP ecx/JA exit + This avoids the artificial J1/J2 split that required a separate + offset-exit proof. *) + SUBGOAL_THEN `N = i + 1` ASSUME_TAC THENL + [UNDISCH_TAC `~(i + 1 < N)` THEN UNDISCH_TAC `i:num < N` THEN ARITH_TAC; + ALL_TAC] THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC (36--37) THEN + FIRST_X_ASSUM(DISJ_CASES_TAC o check (is_disj o concl)) THENL + [(* J1: offset exit (832 < 24 * (N + 1) disjunct holds). + Split on whether count-exit ALSO fires: + * J1-A: 248 < curlen + lr → JAE at s37 fires directly, same as J2. + * J1-B: curlen + lr <= 248 → JAE falls through, CMP ecx/JA at s38-39 + fires because 808 < 24*(i+1) (from disjunct + N=i+1). *) + ASM_CASES_TAC + `248 < curlen + LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) (inlist:(24 word)list)))` + THENL + [(* J1-A: JAE at s37 fires. Derive J2's precondition, run J2 body. *) + SUBGOAL_THEN + `248 < LENGTH(REJ_SAMPLE(SUB_LIST(0,8 * N) (inlist:(24 word)list)))` + ASSUME_TAC THENL + [MP_TAC(ISPECL [`inlist:(24 word)list`; `i:num`; `curlist:int32 list`; + `curlen:num`] SIMD_ITERATION_BRIDGE) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[] THEN + UNDISCH_TAC `24 * (i + 1) <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + STRIP_TAC THEN + ASM_REWRITE_TAC[ARITH_RULE `8 * (i + 1) = 8 * i + 8`; + SUB_LIST_SPLIT; REJ_SAMPLE_APPEND; LENGTH_APPEND; + ADD_CLAUSES]; + ALL_TAC] THEN + (* J1-A body is identical to J2 body; run through it. + Must keep this in sync if J2 body changes. *) + SUBGOAL_THEN `newlen <= 8` ASSUME_TAC THENL + [MP_TAC(ISPECL [`inlist:(24 word)list`; `i:num`; `curlist:int32 list`; + `curlen:num`] SIMD_ITERATION_BRIDGE) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[] THEN + UNDISCH_TAC `24 * (i + 1) <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + STRIP_TAC THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + (fun (asl,w) -> + try + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + let b256_th = snd(find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "bytes256")) (lhs(concl th)) && + not(can (find_term (has_const "MAYCHANGE")) (concl th))) asl) in + let b256_state = find_term (fun t -> + try let n = fst(dest_var t) in + String.length n > 1 && n.[0] = 's' && type_of t = `:x86state` + with _ -> false) (lhs(concl b256_th)) in + let b256_state_name = fst(dest_var b256_state) in + let ymm_th = snd(find (fun (_,th) -> + is_eq(concl th) && type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) && + can (find_term (has_var b256_state_name)) (lhs(concl th))) asl) in + let b256_ymm3 = TRANS b256_th (SYM ymm_th) in + let curlen_bound = snd(find (fun (_,th) -> + try concl th = `curlen <= 248` with _ -> false) asl) in + let vwe32 = CONV_RULE NUM_REDUCE_CONV + (REWRITE_RULE[DIMINDEX_32] (INST_TYPE [`:32`,`:N`] VAL_WORD_EQ)) in + let vwe64 = CONV_RULE NUM_REDUCE_CONV + (REWRITE_RULE[DIMINDEX_64] (INST_TYPE [`:64`,`:N`] VAL_WORD_EQ)) in + let cn32 = MATCH_MP vwe32 (CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 4294967296`) + curlen_bound)) in + let cn64 = MATCH_MP vwe64 (CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 18446744073709551616`) + curlen_bound)) in + let b256_norm = REWRITE_RULE[cn32; cn64] b256_ymm3 in + let val_eq = AP_TERM `val:int256->num` b256_norm in + let bytes32_eq = CONV_RULE(LAND_CONV( + REWRITE_CONV[READ_COMPONENT_COMPOSE; VAL_READ_BYTES256] THENC + REWRITE_CONV[GSYM READ_COMPONENT_COMPOSE])) val_eq in + let vpermd_raw = snd(List.find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "MOD")) (concl th) && + can (find_term (has_const "num_of_wordlist")) (concl th) && + can (find_term (has_var "newlen")) (concl th)) asl) in + let is_ymm3_read th = + is_eq(concl th) && type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) in + let vpermd_states = List.filter (fun v -> + let n = fst(dest_var v) in String.length n > 1 && n.[0] = 's' && + type_of v = `:x86state`) (frees(concl vpermd_raw)) in + let vp_state_name = fst(dest_var(List.hd vpermd_states)) in + let vpermd = try + let ymm_b32 = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var b256_state_name)) + (lhs(concl th))) asl) in + let ymm_vp = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var vp_state_name)) + (lhs(concl th))) asl) in + let ymm_eq = TRANS ymm_b32 (SYM ymm_vp) in + let val_eq = AP_TERM `val:int256->num` ymm_eq in + REWRITE_RULE[GSYM val_eq] vpermd_raw + with _ -> vpermd_raw in + let newlen_bound = snd(List.find (fun (_,th) -> + try is_binary "<=" (concl th) && + (has_var "newlen" (lhand(concl th))) && + dest_small_numeral(rand(concl th)) = 8 + with _ -> false) asl) in + let bridge = MATCH_MP VPERMD_MEMORY_BRIDGE + (CONJ bytes32_eq (CONJ vpermd newlen_bound)) in + ASSUME_TAC bridge (asl,w) + with _ -> failwith "J1-A PRE-ENSURES") THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `N = i + 1` (fun th -> + REWRITE_TAC[th; ARITH_RULE `8 * (i + 1) = 8 * i + 8`; + ARITH_RULE `24 * (i + 1) = 24 * i + 24`]) THENL + [UNDISCH_TAC `~(i + 1 < N)` THEN UNDISCH_TAC `i:num < N` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[SUB_LIST_SPLIT; REJ_SAMPLE_APPEND; LENGTH_APPEND; + NUM_OF_WORDLIST_APPEND] THEN + REWRITE_TAC[ADD_CLAUSES] THEN + ABBREV_TAC + `lr = LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) (inlist:(24 word)list)))` THEN + SUBGOAL_THEN `lr <= 8` ASSUME_TAC THENL + [EXPAND_TAC "lr" THEN REWRITE_TAC[REJ_SAMPLE] THEN + TRANS_TAC LE_TRANS `LENGTH(MAP (\x:24 word. word(val x MOD 2 EXP 23):int32) (SUB_LIST(8*i,8) (inlist:(24 word)list)))` THEN + REWRITE_TAC[LENGTH_FILTER; LENGTH_MAP; LENGTH_SUB_LIST] THEN + ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `248 < curlen + lr` ASSUME_TAC THENL + [EXPAND_TAC "lr" THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + (* Resolve `read RIP s37 = word(pc + 181)` (JAE fires) and rewrite + into events to eliminate the COND that would otherwise stall + DISCHARGE_MEMSAFE_ASM_TAC's existential search. *) + SUBGOAL_THEN `read RIP s37 = word(pc + 181):int64` ASSUME_TAC THENL + [FIRST_X_ASSUM(fun th -> + if can (find_term ((=) `read RIP s37`)) (concl th) && + is_eq(concl th) + then SUBST1_TAC th else failwith "") THEN + MATCH_MP_TAC(TAUT `p ==> (if p then (a:int64) else b) = a`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ + lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ + curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT; MOD_MOD_REFL] THEN + UNDISCH_TAC `248 < curlen + lr` THEN ARITH_TAC; + ALL_TAC] THEN + (* Rewrite events chain: derive `(if COND_s37 then word(pc+181) else + word(pc+111)) = word(pc+181)` by transitivity from the COND-laden + read RIP s37 hypothesis (still in asl) and the clean fact we just + proved, then RULE_ASSUM with it. This eliminates the COND from + EventJump entries in the events chain. *) + (fun (asl,w) -> + try + let cond_th = snd(List.find (fun (_,th) -> + let c = concl th in + is_eq c && + can (find_term ((=) `read RIP s37`)) c && + can (find_term (fun t -> + try fst(dest_const t) = "COND" with _ -> false)) (rhs c)) asl) in + let clean_th = snd(List.find (fun (_,th) -> + concl th = `read RIP s37 = word(pc + 181):int64`) asl) in + (* cond_th: read RIP s37 = (if c then a else b). + clean_th: read RIP s37 = word(pc+181). + Want: (if c then a else b) = word(pc+181). *) + let cond_eq_clean = TRANS (SYM cond_th) clean_th in + Printf.printf "DBG: COND rewrite TRANS produced %s\n%!" (string_of_term (concl cond_eq_clean)); + RULE_ASSUM_TAC (REWRITE_RULE [cond_eq_clean]) (asl,w) + with e -> Printf.printf "DBG: COND rewrite failed: %s\n%!" (Printexc.to_string e); ALL_TAC (asl,w)) THEN + TRY (FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + let c = concl th in + is_eq c && can (find_term ((=) `read RIP s37`)) c && + can (find_term (fun t -> + try fst(dest_const t) = "COND" with _ -> false)) (rhs c)))) THEN + FIRST_X_ASSUM(SUBST_ALL_TAC) THEN + (fun (asl, w) -> + try + let newlen_eq_lr = snd(List.find (fun (_, th) -> + let c = concl th in + is_eq c && + (try fst(dest_var(lhs c)) = "newlen" with _ -> false) && + (try fst(dest_var(rhs c)) = "lr" with _ -> false)) asl) in + RULE_ASSUM_TAC (REWRITE_RULE [newlen_eq_lr]) (asl, w) + with _ -> ALL_TAC (asl, w)) THEN + DISCARD_ASSUMPTIONS_TAC (fun th -> + let c = concl th in + let fvs = frees c in + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + not(is_eq c && + can (find_term (has_const "read")) c && + can (find_term (has_const "bytes")) c) && + not(can (find_term (has_var "cmp_mask")) c && + is_binary "<=" c) && + not(can (find_term (has_const "memaccess_inbounds")) c) && + (not (forall (fun v -> type_of v = `:num`) fvs) || + exists (fun v -> let n = fst(dest_var v) in + n = "N" || n = "newlen" || n = "curlist") fvs || + is_forall c)) THEN + REPEAT CONJ_TAC THEN + (FIRST_ASSUM ACCEPT_TAC ORELSE + (MATCH_MP_TAC(TAUT `p ==> (if p then (a:int64) else b) = a`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ + lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ + curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `248 <= curlen + lr` ASSUME_TAC THENL + [UNDISCH_TAC `248 < curlen + lr` THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(curlen + lr) - 248 < 18446744073709551616` + ASSUME_TAC THENL + [UNDISCH_TAC `curlen + lr < 18446744073709551616` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT; MOD_MOD_REFL] THEN + UNDISCH_TAC `248 < curlen + lr` THEN ARITH_TAC) ORELSE + (REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ + lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ + curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC) ORELSE + (REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + UNDISCH_TAC `24 * i <= 808` THEN + SPEC_TAC(`24 * i`,`n:num`) THEN GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `n < 4294967296 /\ n < 18446744073709551616 /\ + n + 24 < 4294967296 /\ + n + 24 < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `n <= 808` THEN ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC) ORELSE + (REWRITE_TAC[DIMINDEX_32; + ARITH_RULE `4 * (curlen + lr) = 4 * curlen + 4 * lr`; + ARITH_RULE `32 * curlen = 8 * (4 * curlen)`] THEN + REWRITE_TAC[MEMORY_BYTES_SPLIT] THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[EQ_ADD_LCANCEL; EQ_MULT_LCANCEL; EXP_EQ_0; ARITH_EQ] THEN + ASM_REWRITE_TAC[] THEN NO_TAC) ORELSE + (W(fun (_,w) -> + if (try let n = fst(dest_var(fst(dest_exists w))) in + n = "e_acc'" || n = "e_acc" || String.length n >= 5 && + String.sub n 0 5 = "e_acc" + with _ -> false) + then DISCHARGE_MEMSAFE_ASM_TAC + else NO_TAC)) ORELSE + ASM_REWRITE_TAC[]); + + (* J1-B: JAE at s37 falls through to pc+111. VSTEPS 38-39 do CMP ecx/JA + and exit to pc+181 because 808 < 24*(i+1) (from offset disjunct). *) + ABBREV_TAC + `lr = LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) (inlist:(24 word)list)))` THEN + SUBGOAL_THEN `lr <= 8` ASSUME_TAC THENL + [EXPAND_TAC "lr" THEN REWRITE_TAC[REJ_SAMPLE] THEN + TRANS_TAC LE_TRANS `LENGTH(MAP (\x:24 word. word(val x MOD 2 EXP 23):int32) (SUB_LIST(8*i,8) (inlist:(24 word)list)))` THEN + REWRITE_TAC[LENGTH_FILTER; LENGTH_MAP; LENGTH_SUB_LIST] THEN + ARITH_TAC; + ALL_TAC] THEN + (* Resolve RIP s37 = word(pc + 111) via COND simplification *) + SUBGOAL_THEN `read RIP s37 = word(pc + 111):int64` MP_TAC THENL + [FIRST_X_ASSUM(fun th -> + if can (find_term ((=) `read RIP s37`)) (concl th) && + is_eq(concl th) + then SUBST1_TAC th else failwith "") THEN + MATCH_MP_TAC(TAUT `~p ==> (if p then (a:int64) else b) = b`) THEN + REWRITE_TAC[DE_MORGAN_THM; NOT_CLAUSES; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ + lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ + curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT; MOD_MOD_REFL] THEN + UNDISCH_TAC `~(248 < curlen + lr)` THEN + ARITH_TAC; + ALL_TAC] THEN + DISCH_THEN ASSUME_TAC THEN + FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + let c = concl th in + can (find_term ((=) `read RIP s37`)) c && is_eq c && + can (find_term (fun t -> + try fst(dest_const t) = "COND" with _ -> false)) (rhs c))) THEN + X86_VSTEPS_TAC MLDSA_REJ_UNIFORM_EXEC (38--39) THEN + (* Resolve RIP s39 = word(pc + 181) via JA analysis using BITBLAST_RULE + on the VAL_WORD_ZX/SUB expression. *) + (* Resolve RIP s39 = word(pc + 181) — mirror original proof pattern. *) + SUBGOAL_THEN `read RIP s39 = word(pc + 181):int64` MP_TAC THENL + [FIRST_X_ASSUM(fun th -> + if can (find_term ((=) `read RIP s39`)) (concl th) && + is_eq(concl th) + then SUBST1_TAC th else failwith "") THEN + MATCH_MP_TAC(TAUT `p ==> (if p then (a:int64) else b) = a`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `24 * i < 4294967296 /\ 24 * i < 18446744073709551616 /\ + 24 * i + 24 < 4294967296 /\ + 24 * i + 24 < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `24 * i <= 808` THEN ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT; MOD_MOD_REFL] THEN + UNDISCH_TAC `832 < 24 * (N + 1)` THEN + UNDISCH_TAC `N = i + 1` THEN ARITH_TAC; + ALL_TAC] THEN + DISCH_THEN ASSUME_TAC THEN + FIRST_X_ASSUM(K ALL_TAC o check (fun th -> + let c = concl th in + can (find_term ((=) `read RIP s39`)) c && is_eq c && + can (find_term (fun t -> + try fst(dest_const t) = "COND" with _ -> false)) (rhs c))) THEN + (* Rest mirrors J2's body, adapted for s39 *) + SUBGOAL_THEN `newlen <= 8` ASSUME_TAC THENL + [MP_TAC(ISPECL [`inlist:(24 word)list`; `i:num`; `curlist:int32 list`; + `curlen:num`] SIMD_ITERATION_BRIDGE) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[] THEN + UNDISCH_TAC `24 * (i + 1) <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + STRIP_TAC THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + (fun (asl,w) -> + try + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + let b256_th = snd(find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "bytes256")) (lhs(concl th)) && + not(can (find_term (has_const "MAYCHANGE")) (concl th))) asl) in + let b256_state = find_term (fun t -> + try let n = fst(dest_var t) in + String.length n > 1 && n.[0] = 's' && type_of t = `:x86state` + with _ -> false) (lhs(concl b256_th)) in + let b256_state_name = fst(dest_var b256_state) in + let ymm_th = snd(find (fun (_,th) -> + is_eq(concl th) && type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) && + can (find_term (has_var b256_state_name)) (lhs(concl th))) asl) in + let b256_ymm3 = TRANS b256_th (SYM ymm_th) in + let curlen_bound = snd(find (fun (_,th) -> + try concl th = `curlen <= 248` with _ -> false) asl) in + let vwe32 = CONV_RULE NUM_REDUCE_CONV + (REWRITE_RULE[DIMINDEX_32] (INST_TYPE [`:32`,`:N`] VAL_WORD_EQ)) in + let vwe64 = CONV_RULE NUM_REDUCE_CONV + (REWRITE_RULE[DIMINDEX_64] (INST_TYPE [`:64`,`:N`] VAL_WORD_EQ)) in + let cn32 = MATCH_MP vwe32 (CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 4294967296`) + curlen_bound)) in + let cn64 = MATCH_MP vwe64 (CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 18446744073709551616`) + curlen_bound)) in + let b256_norm = REWRITE_RULE[cn32; cn64] b256_ymm3 in + let val_eq = AP_TERM `val:int256->num` b256_norm in + let bytes32_eq = CONV_RULE(LAND_CONV( + REWRITE_CONV[READ_COMPONENT_COMPOSE; VAL_READ_BYTES256] THENC + REWRITE_CONV[GSYM READ_COMPONENT_COMPOSE])) val_eq in + let vpermd_raw = snd(List.find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "MOD")) (concl th) && + can (find_term (has_const "num_of_wordlist")) (concl th) && + can (find_term (has_var "newlen")) (concl th)) asl) in + let is_ymm3_read th = + is_eq(concl th) && type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) in + let vpermd_states = List.filter (fun v -> + let n = fst(dest_var v) in String.length n > 1 && n.[0] = 's' && + type_of v = `:x86state`) (frees(concl vpermd_raw)) in + let vp_state_name = fst(dest_var(List.hd vpermd_states)) in + let vpermd = try + let ymm_b32 = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var b256_state_name)) + (lhs(concl th))) asl) in + let ymm_vp = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var vp_state_name)) + (lhs(concl th))) asl) in + let ymm_eq = TRANS ymm_b32 (SYM ymm_vp) in + let val_eq = AP_TERM `val:int256->num` ymm_eq in + REWRITE_RULE[GSYM val_eq] vpermd_raw + with _ -> vpermd_raw in + let newlen_bound = snd(List.find (fun (_,th) -> + try is_binary "<=" (concl th) && + (has_var "newlen" (lhand(concl th))) && + dest_small_numeral(rand(concl th)) = 8 + with _ -> false) asl) in + let bridge = MATCH_MP VPERMD_MEMORY_BRIDGE + (CONJ bytes32_eq (CONJ vpermd newlen_bound)) in + ASSUME_TAC bridge (asl,w) + with _ -> failwith "J1-B PRE-ENSURES") THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `N = i + 1` (fun th -> + REWRITE_TAC[th; ARITH_RULE `8 * (i + 1) = 8 * i + 8`; + ARITH_RULE `24 * (i + 1) = 24 * i + 24`]) THENL + [UNDISCH_TAC `~(i + 1 < N)` THEN UNDISCH_TAC `i:num < N` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[SUB_LIST_SPLIT; REJ_SAMPLE_APPEND; LENGTH_APPEND; + NUM_OF_WORDLIST_APPEND] THEN + REWRITE_TAC[ADD_CLAUSES] THEN + (* lr already abbreviated in J1-B prologue *) + ASM_REWRITE_TAC[] THEN + REPEAT CONJ_TAC THEN + (FIRST_ASSUM ACCEPT_TAC ORELSE + (REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ + lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ + curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC) ORELSE + (REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + UNDISCH_TAC `24 * i <= 808` THEN + SPEC_TAC(`24 * i`,`n:num`) THEN GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `n < 4294967296 /\ n < 18446744073709551616 /\ + n + 24 < 4294967296 /\ + n + 24 < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `n <= 808` THEN ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC) ORELSE + (REWRITE_TAC[DIMINDEX_32; + ARITH_RULE `4 * (curlen + lr) = 4 * curlen + 4 * lr`; + ARITH_RULE `32 * curlen = 8 * (4 * curlen)`] THEN + REWRITE_TAC[MEMORY_BYTES_SPLIT] THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[EQ_ADD_LCANCEL; EQ_MULT_LCANCEL; EXP_EQ_0; ARITH_EQ] THEN + (fun (asl, w) -> + try + let newlen_eq_lr = snd(List.find (fun (_, th) -> + let c = concl th in + is_eq c && + (try fst(dest_var(lhs c)) = "newlen" with _ -> false) && + (try fst(dest_var(rhs c)) = "lr" with _ -> false)) asl) in + RULE_ASSUM_TAC (REWRITE_RULE [newlen_eq_lr]) (asl, w) + with _ -> ALL_TAC (asl, w)) THEN + ASM_REWRITE_TAC[] THEN NO_TAC) ORELSE + (W(fun (_,w) -> + if (try let n = fst(dest_var(fst(dest_exists w))) in + n = "e_acc'" || n = "e_acc" || String.length n >= 5 && + String.sub n 0 5 = "e_acc" + with _ -> false) + then DISCHARGE_MEMSAFE_ASM_TAC + else NO_TAC)) ORELSE + ASM_REWRITE_TAC[])]; + (* J2: 248 < LENGTH(REJ_SAMPLE(SUB_LIST(0,8*N))): count exit. + Body is identical to J1-A's (which already has 248 < LENGTH... in + assumptions because J1-A derives it; J2 has it natively from the + disjunct case selection). *) + SUBGOAL_THEN `newlen <= 8` ASSUME_TAC THENL + [MP_TAC(ISPECL [`inlist:(24 word)list`; `i:num`; `curlist:int32 list`; + `curlen:num`] SIMD_ITERATION_BRIDGE) THEN + ANTS_TAC THENL + [ASM_REWRITE_TAC[] THEN + UNDISCH_TAC `24 * (i + 1) <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + STRIP_TAC THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + (fun (asl,w) -> + try + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + let b256_th = snd(find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "bytes256")) (lhs(concl th)) && + not(can (find_term (has_const "MAYCHANGE")) (concl th))) asl) in + let b256_state = find_term (fun t -> + try let n = fst(dest_var t) in + String.length n > 1 && n.[0] = 's' && type_of t = `:x86state` + with _ -> false) (lhs(concl b256_th)) in + let b256_state_name = fst(dest_var b256_state) in + let ymm_th = snd(find (fun (_,th) -> + is_eq(concl th) && type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) && + can (find_term (has_var b256_state_name)) (lhs(concl th))) asl) in + let b256_ymm3 = TRANS b256_th (SYM ymm_th) in + let curlen_bound = snd(find (fun (_,th) -> + try concl th = `curlen <= 248` with _ -> false) asl) in + let vwe32 = CONV_RULE NUM_REDUCE_CONV + (REWRITE_RULE[DIMINDEX_32] (INST_TYPE [`:32`,`:N`] VAL_WORD_EQ)) in + let vwe64 = CONV_RULE NUM_REDUCE_CONV + (REWRITE_RULE[DIMINDEX_64] (INST_TYPE [`:64`,`:N`] VAL_WORD_EQ)) in + let cn32 = MATCH_MP vwe32 (CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 4294967296`) + curlen_bound)) in + let cn64 = MATCH_MP vwe64 (CONV_RULE NUM_REDUCE_CONV + (MATCH_MP (ARITH_RULE `n <= 248 ==> n < 18446744073709551616`) + curlen_bound)) in + let b256_norm = REWRITE_RULE[cn32; cn64] b256_ymm3 in + let val_eq = AP_TERM `val:int256->num` b256_norm in + let bytes32_eq = CONV_RULE(LAND_CONV( + REWRITE_CONV[READ_COMPONENT_COMPOSE; VAL_READ_BYTES256] THENC + REWRITE_CONV[GSYM READ_COMPONENT_COMPOSE])) val_eq in + let vpermd_raw = snd(List.find (fun (_,th) -> + is_eq(concl th) && + can (find_term (has_const "MOD")) (concl th) && + can (find_term (has_const "num_of_wordlist")) (concl th) && + can (find_term (has_var "newlen")) (concl th)) asl) in + let is_ymm3_read th = + is_eq(concl th) && type_of(rhs(concl th)) = `:int256` && + can (find_term (has_const "YMM3")) (lhs(concl th)) in + let vpermd_states = List.filter (fun v -> + let n = fst(dest_var v) in String.length n > 1 && n.[0] = 's' && + type_of v = `:x86state`) (frees(concl vpermd_raw)) in + let vp_state_name = fst(dest_var(List.hd vpermd_states)) in + let vpermd = try + let ymm_b32 = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var b256_state_name)) + (lhs(concl th))) asl) in + let ymm_vp = snd(List.find (fun (_,th) -> + is_ymm3_read th && + can (find_term (has_var vp_state_name)) + (lhs(concl th))) asl) in + let ymm_eq = TRANS ymm_b32 (SYM ymm_vp) in + let val_eq = AP_TERM `val:int256->num` ymm_eq in + REWRITE_RULE[GSYM val_eq] vpermd_raw + with _ -> vpermd_raw in + let newlen_bound = snd(List.find (fun (_,th) -> + try is_binary "<=" (concl th) && + (has_var "newlen" (lhand(concl th))) && + dest_small_numeral(rand(concl th)) = 8 + with _ -> false) asl) in + let bridge = MATCH_MP VPERMD_MEMORY_BRIDGE + (CONJ bytes32_eq (CONJ vpermd newlen_bound)) in + ASSUME_TAC bridge (asl,w) + with e -> Printf.printf "DBG: J2 PRE-ENSURES failed: %s\n%!" (Printexc.to_string e); failwith "J2 PRE-ENSURES") THEN + ENSURES_FINAL_STATE_TAC THEN + ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `N = i + 1` (fun th -> + REWRITE_TAC[th; ARITH_RULE `8 * (i + 1) = 8 * i + 8`; + ARITH_RULE `24 * (i + 1) = 24 * i + 24`]) THENL + [UNDISCH_TAC `~(i + 1 < N)` THEN UNDISCH_TAC `i:num < N` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[SUB_LIST_SPLIT; REJ_SAMPLE_APPEND; LENGTH_APPEND; + NUM_OF_WORDLIST_APPEND] THEN + REWRITE_TAC[ADD_CLAUSES] THEN + ABBREV_TAC + `lr = LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) (inlist:(24 word)list)))` THEN + SUBGOAL_THEN `lr <= 8` ASSUME_TAC THENL + [EXPAND_TAC "lr" THEN REWRITE_TAC[REJ_SAMPLE] THEN + TRANS_TAC LE_TRANS `LENGTH(MAP (\x:24 word. word(val x MOD 2 EXP 23):int32) (SUB_LIST(8*i,8) (inlist:(24 word)list)))` THEN + REWRITE_TAC[LENGTH_FILTER; LENGTH_MAP; LENGTH_SUB_LIST] THEN + ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `248 < curlen + lr` ASSUME_TAC THENL + [(* N=i+1 substitution didn't propagate into disjunct hyp; do it + manually. Reduce LENGTH(REJ_SAMPLE(SUB_LIST(0, 8*(i+1)) inlist)) + to curlen + lr via SUB_LIST_SPLIT + ABBREVs. *) + UNDISCH_TAC `248 < LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N) (inlist:(24 word)list)))` THEN + UNDISCH_TAC `N = i + 1` THEN DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[ARITH_RULE `8 * (i + 1) = 8 * i + 8`; + SUB_LIST_SPLIT; REJ_SAMPLE_APPEND; LENGTH_APPEND; + ADD_CLAUSES] THEN + UNDISCH_TAC `REJ_SAMPLE(SUB_LIST(0, 8 * i) (inlist:(24 word)list)) = curlist` THEN + DISCH_THEN SUBST1_TAC THEN + UNDISCH_TAC `LENGTH(curlist:int32 list) = curlen` THEN + DISCH_THEN SUBST1_TAC THEN + UNDISCH_TAC `LENGTH(REJ_SAMPLE(SUB_LIST(8*i, 8) (inlist:(24 word)list))) = lr` THEN + DISCH_THEN SUBST1_TAC THEN + ARITH_TAC; + ALL_TAC] THEN + (* J2 has bridge `read(...4*newlen) s37 = num_of_wordlist (REJ_SAMPLE...)` + from PRE-ENSURES VPERMD_MEMORY_BRIDGE. After ABBREV_TAC of `lr = + LENGTH(REJ_SAMPLE(SUB_LIST(8*i,8) inlist))`, the asl already has + a `newlen = lr` (from prior `newlen = LENGTH(REJ_SAMPLE...)` being + rewritten by ABBREV). SUBST_ALL_TAC of this fact eliminates + `newlen` from the bridge so ASM_REWRITE in the MEM branch can + match `4 * lr` against the goal's `4 * curlen + 4 * lr`. *) + (fun (asl, w) -> + try + let newlen_eq_lr = snd(List.find (fun (_, th) -> + let c = concl th in + is_eq c && + (try fst(dest_var(lhs c)) = "newlen" with _ -> false) && + (try fst(dest_var(rhs c)) = "lr" with _ -> false)) asl) in + RULE_ASSUM_TAC (REWRITE_RULE [newlen_eq_lr]) (asl, w) + with _ -> ALL_TAC (asl, w)) THEN + (TRY (DISCARD_ASSUMPTIONS_TAC (fun th -> + try + let c = concl th in + let fvs = frees c in + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + not(is_eq c && + can (find_term (has_const "read")) c && + can (find_term (has_const "bytes")) c) && + not(can (find_term (has_var "cmp_mask")) c && + is_binary "<=" c) && + not(can (find_term (has_const "memaccess_inbounds")) c) && + (not (forall (fun v -> type_of v = `:num`) fvs) || + exists (fun v -> try let n = fst(dest_var v) in + n = "N" || n = "newlen" || n = "curlist" with _ -> false) fvs || + (try is_forall c with _ -> false)) + with _ -> false))) THEN + REPEAT CONJ_TAC THEN + (* PR1040 closing ladder: each TRY catches failure, so the right + tactic for each conjunct closes it independently. *) + TRY(FIRST_ASSUM ACCEPT_TAC) THEN + TRY(ASM_REWRITE_TAC[] THEN NO_TAC) THEN + TRY(ASM_ARITH_TAC) THEN + TRY(MATCH_MP_TAC(TAUT `p ==> (if p then (a:int64) else b) = a`) THEN + REWRITE_TAC[NOT_CLAUSES; DE_MORGAN_THM; + VAL_WORD_ZX_GEN; VAL_WORD_SUB_CASES; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ + lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ + curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `248 <= curlen + lr` ASSUME_TAC THENL + [UNDISCH_TAC `248 < curlen + lr` THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(curlen + lr) - 248 < 18446744073709551616` + ASSUME_TAC THENL + [UNDISCH_TAC `curlen + lr < 18446744073709551616` THEN ARITH_TAC; + ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT; MOD_MOD_REFL] THEN + UNDISCH_TAC `248 < curlen + lr` THEN ARITH_TAC) THEN + TRY(REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `curlen < 4294967296 /\ lr < 4294967296 /\ + curlen < 18446744073709551616 /\ + lr < 18446744073709551616 /\ + curlen + lr < 4294967296 /\ + curlen + lr < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `curlen <= 248` THEN UNDISCH_TAC `lr <= 8` THEN + ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC) THEN + TRY(REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD_ADD; + VAL_WORD; DIMINDEX_32; DIMINDEX_64] THEN + CONV_TAC NUM_REDUCE_CONV THEN + UNDISCH_TAC `24 * i <= 808` THEN + SPEC_TAC(`24 * i`,`n:num`) THEN GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `n < 4294967296 /\ n < 18446744073709551616 /\ + n + 24 < 4294967296 /\ + n + 24 < 18446744073709551616` + STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `n <= 808` THEN ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[MOD_LT] THEN ARITH_TAC) THEN + TRY(REWRITE_TAC[DIMINDEX_32; + ARITH_RULE `4 * (curlen + lr) = 4 * curlen + 4 * lr`; + ARITH_RULE `32 * curlen = 8 * (4 * curlen)`] THEN + REWRITE_TAC[MEMORY_BYTES_SPLIT] THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[EQ_ADD_LCANCEL; EQ_MULT_LCANCEL; EXP_EQ_0; ARITH_EQ] THEN + ASM_REWRITE_TAC[] THEN NO_TAC) THEN + (* DBG: log if mem branch failed and goal still has memory pattern. *) + (fun (asl,w) -> + let s = string_of_term w in + let m = String.length s in + let pat = "memory :> bytes (res" in + let pm = String.length pat in + let has_pat = + let rec check i = i + pm <= m && + (String.sub s i pm = pat || check (i+1)) in + m >= pm && check 0 in + if has_pat then + Printf.printf "DBG: J2 MEM BRANCH FAILED, residual=%s\n%!" + (if m < 300 then s else String.sub s 0 300 ^ "..."); + ALL_TAC (asl,w)) THEN + TRY (W(fun (_,w) -> + if (try let n = fst(dest_var(fst(dest_exists w))) in + n = "e_acc'" || n = "e_acc" || String.length n >= 5 && + String.sub n 0 5 = "e_acc" + with _ -> false) + then DISCHARGE_MEMSAFE_ASM_TAC + else NO_TAC)) THEN + TRY DISCHARGE_MEMSAFE_ASM_TAC]]; + + (* ================================================================= *) + (* Subgoal 3: Post-loop *) + (* *) + (* Entry: pc+181 with REJ_SAMPLE(SUB_LIST(0,8*N)) accumulated and *) + (* `?e_acc. read events s = APPEND e_acc e /\ memaccess_inbounds`. *) + (* Code structure: pc+181: CMP eax,256; JAE; pc+188: CMP ecx,837; *) + (* JA; pc+196..240: scalar coefficient loop; pc+242: RET. *) + (* ================================================================= *) + CONV_TAC(RATOR_CONV(LAND_CONV(TOP_DEPTH_CONV let_CONV))) THEN + MAP_EVERY ABBREV_TAC + [`outlist = REJ_SAMPLE (SUB_LIST (0,8 * N) inlist)`; + `outlen = LENGTH(outlist:int32 list)`] THEN + SUBGOAL_THEN + `LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N) (inlist:(24 word)list))) = outlen` + ASSUME_TAC THENL + [UNDISCH_TAC `REJ_SAMPLE (SUB_LIST (0,8 * N) (inlist:(24 word)list)) = outlist` THEN + DISCH_THEN(fun th -> REWRITE_TAC[th]) THEN + UNDISCH_TAC `LENGTH (outlist:int32 list) = outlen` THEN + DISCH_THEN(fun th -> REWRITE_TAC[th]); + ALL_TAC] THEN + SUBGOAL_THEN + `24 * N <= 832 /\ + LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * (N - 1)) (inlist:(24 word)list))) <= 248` + STRIP_ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o SPEC `N - 1`) THEN + ANTS_TAC THENL [UNDISCH_TAC `~(N = 0)` THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(N - 1) + 1 = N` SUBST1_TAC THENL + [UNDISCH_TAC `~(N = 0)` THEN ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[]; + ALL_TAC] THEN + SUBGOAL_THEN `outlen <= 256` ASSUME_TAC THENL + [MP_TAC(ISPECL [`inlist:(24 word)list`; `N - 1`; + `REJ_SAMPLE(SUB_LIST(0, 8*(N-1)) (inlist:(24 word)list))`; + `LENGTH(REJ_SAMPLE(SUB_LIST(0, 8*(N-1)) (inlist:(24 word)list)))`] + SIMD_ITERATION_BRIDGE) THEN + ANTS_TAC THENL + [REWRITE_TAC[] THEN + SUBGOAL_THEN `N - 1 + 1 = N` SUBST1_TAC THENL + [UNDISCH_TAC `~(N = 0)` THEN ARITH_TAC; ALL_TAC] THEN + UNDISCH_TAC `LENGTH (inlist:(24 word)list) = 280` THEN + UNDISCH_TAC `24 * N <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `N - 1 + 1 = N` SUBST1_TAC THENL + [UNDISCH_TAC `~(N = 0)` THEN ARITH_TAC; ALL_TAC] THEN + STRIP_TAC THEN + UNDISCH_TAC + `LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N) (inlist:(24 word)list))) = outlen` THEN + UNDISCH_TAC + `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N) (inlist:(24 word)list))) = + LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * (N - 1)) inlist)) + + LENGTH (REJ_SAMPLE (SUB_LIST (8 * (N - 1),8) inlist))` THEN + UNDISCH_TAC + `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * (N - 1)) (inlist:(24 word)list))) <= 248` THEN + UNDISCH_TAC + `LENGTH (REJ_SAMPLE (SUB_LIST (8 * (N - 1),8) (inlist:(24 word)list))) <= 8` THEN + ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN + `?j. 256 <= LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N + j) (inlist:(24 word)list))) \/ + 837 < 24 * N + 3 * j` + MP_TAC THENL + [EXISTS_TAC `280:num` THEN DISJ2_TAC THEN + UNDISCH_TAC `24 * N <= 832` THEN ARITH_TAC; + GEN_REWRITE_TAC LAND_CONV [num_WOP]] THEN + DISCH_THEN(X_CHOOSE_THEN `K:num` (CONJUNCTS_THEN2 ASSUME_TAC MP_TAC)) THEN + DISCH_THEN(fun th -> + ASSUME_TAC(REWRITE_RULE[DE_MORGAN_THM; NOT_LE; NOT_LT] th)) THEN + ASM_CASES_TAC `K = 0` THENL + [(* K = 0: no scalar iterations. JAE at pc+181 fires to pc+242. *) + SUBGOAL_THEN `outlen = 256` ASSUME_TAC THENL + [RULE_ASSUM_TAC(REWRITE_RULE[ASSUME `K = 0`; ADD_CLAUSES; MULT_CLAUSES]) THEN + UNDISCH_TAC + `256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N) (inlist:(24 word)list))) \/ + 837 < 24 * N` THEN + UNDISCH_TAC + `LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N) (inlist:(24 word)list))) = outlen` THEN + DISCH_THEN(fun th -> REWRITE_TAC[th]) THEN + UNDISCH_TAC `outlen <= 256` THEN + UNDISCH_TAC `24 * N <= 832` THEN ARITH_TAC; + ALL_TAC] THEN + ENSURES_INIT_TAC "s0" THEN + STRIP_EXISTS_ASSUM_TAC THEN + X86_STEPS_TAC MLDSA_REJ_UNIFORM_EXEC [40;41] THEN + RULE_ASSUM_TAC(REWRITE_RULE[ASSUME `outlen = 256`]) THEN + (* RIP s41 = pc+242 already resolved by VSTEPS (since outlen=256 makes + JAE fire statically); no COND elimination needed. *) + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ASM_REWRITE_TAC[] THEN REPEAT CONJ_TAC THEN + TRY (FIRST_ASSUM ACCEPT_TAC) THEN + TRY (REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC) THEN + TRY (W(fun (asl,w) -> + try + let _, body = dest_exists w in + if can (find_term (fun t -> + try fst(dest_const t) = "memaccess_inbounds" with _ -> false)) body + then + try + let _, ainbds_th = List.find (fun (_,th) -> + let c = concl th in + try fst(dest_const(rator(rator(rator c)))) = "memaccess_inbounds" + with _ -> false) asl in + let e_var = rand(rator(rator(concl ainbds_th))) in + (DISCHARGE_MEMSAFE_ASM_TAC ORELSE + (EXISTS_TAC e_var THEN ASM_REWRITE_TAC[])) + with _ -> DISCHARGE_MEMSAFE_ASM_TAC + else NO_TAC + with _ -> NO_TAC)); + (* K > 0: scalar tail runs K iterations. Use ENSURES_WHILE_UP2_TAC with + events-tracking invariant, body discharged via SCALAR_BODY_LEMMA_MEMSAFE. *) + ENSURES_WHILE_UP2_TAC `K:num` `pc + 181` `pc + 242` + `\i s. read RSP s = stackpointer /\ + read (memory :> bytes (buf,840)) s = num_of_wordlist inlist /\ + read (memory :> bytes (table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list) /\ + read RDI s = res /\ read RSI s = buf /\ read RDX s = table /\ + read YMM0 s = + word 115366376096492355175489748997433888275274855593258845241081954797768348401920 /\ + read YMM1 s = + word 226156397384342666605459106258636701594091082888230722833791023177481060351 /\ + read YMM2 s = + word 225935595421087293402315996791205668696012104344015382954355885915737415681 /\ + (let outlist_i = REJ_SAMPLE(SUB_LIST(0, 8 * N + i) (inlist:(24 word)list)) in + let outlen_i = LENGTH outlist_i in + read RAX s = word outlen_i /\ + read RCX s = word(24 * N + 3 * i) /\ + read(memory :> bytes(res, 4 * outlen_i)) s = num_of_wordlist outlist_i) /\ + (exists e_acc. read events s = APPEND e_acc e /\ + memaccess_inbounds e_acc + [buf,840; table,2048] [res,1024])` THEN + ASM_REWRITE_TAC[] THEN REPEAT CONJ_TAC THENL + [(* Init: invariant @ 0 *) + ENSURES_INIT_TAC "s0" THEN + STRIP_EXISTS_ASSUM_TAC THEN + ENSURES_FINAL_STATE_TAC THEN + ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[ADD_CLAUSES; MULT_CLAUSES] THEN ASM_REWRITE_TAC[] THEN + REPEAT CONJ_TAC THEN + TRY (FIRST_ASSUM ACCEPT_TAC) THEN + TRY (REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC) THEN + TRY (W(fun (asl,w) -> + try + let _, body = dest_exists w in + if can (find_term (fun t -> + try fst(dest_const t) = "memaccess_inbounds" with _ -> false)) body + then + (DISCHARGE_MEMSAFE_ASM_TAC ORELSE + (try + let _, ainbds_th = List.find (fun (_,th) -> + let c = concl th in + try fst(dest_const(rator(rator(rator c)))) = "memaccess_inbounds" + with _ -> false) asl in + let e_var = rand(rator(rator(concl ainbds_th))) in + EXISTS_TAC e_var THEN ASM_REWRITE_TAC[] + with _ -> NO_TAC)) + else NO_TAC + with _ -> NO_TAC)); + (* Body: invariant @ i -> invariant @ (i+1). Use SCALAR_BODY_LEMMA_MEMSAFE. *) + X_GEN_TAC `i:num` THEN STRIP_TAC THEN + REWRITE_TAC[GSYM SOME_FLAGS] THEN + MATCH_MP_TAC SCALAR_BODY_LEMMA_MEMSAFE THEN + ASM_REWRITE_TAC[NONOVERLAPPING_CLAUSES] THEN + CONJ_TAC THENL + [X_GEN_TAC `j:num` THEN DISCH_TAC THEN + FIRST_X_ASSUM(MP_TAC o SPEC `j:num` o check (is_forall o concl)) THEN + ASM_REWRITE_TAC[] THEN ARITH_TAC; + FIRST_X_ASSUM(MATCH_ACCEPT_TAC o check (fun th -> + let c = concl th in is_disj c && + can (find_term ((=) `256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))`)) c))]; + (* Post: invariant @ K -> postcondition *) + ENSURES_INIT_TAC "s0" THEN + STRIP_EXISTS_ASSUM_TAC THEN + RULE_ASSUM_TAC(CONV_RULE(TOP_DEPTH_CONV let_CONV)) THEN + FIRST_X_ASSUM(fun th -> + let c = concl th in + if is_conj c && (try can (find_term ((=) `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))`)) c with _ -> false) + then STRIP_ASSUME_TAC th else failwith "not inv") THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + FIRST_X_ASSUM(DISJ_CASES_TAC o check (is_disj o concl)) THENL + [(* count-exit: 256 <= outlen_K *) + SUBGOAL_THEN + `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list))) = 256` + ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o SPEC `K - 1`) THEN + ANTS_TAC THENL [UNDISCH_TAC `~(K = 0)` THEN ARITH_TAC; ALL_TAC] THEN + MP_TAC(ISPECL [`inlist:(24 word)list`; `8 * N + K - 1`] REJ_SAMPLE_STEP_LE) THEN + SUBGOAL_THEN `(8 * N + K - 1) + 1 = 8 * N + K` SUBST1_TAC THENL + [UNDISCH_TAC `~(K = 0)` THEN ARITH_TAC; ALL_TAC] THEN + UNDISCH_TAC + `256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))` THEN + ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `SUB_LIST (0,256) (REJ_SAMPLE (inlist:(24 word)list)) = + REJ_SAMPLE (SUB_LIST (0, 8 * N + K) inlist)` + ASSUME_TAC THENL + [MATCH_MP_TAC REJ_SAMPLE_PREFIX_256 THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + RULE_ASSUM_TAC(REWRITE_RULE[ASSUME + `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list))) = 256`]) THEN + ASM_REWRITE_TAC[] THEN REPEAT CONJ_TAC THEN + TRY (FIRST_ASSUM ACCEPT_TAC) THEN + TRY (REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC) THEN + TRY (W(fun (asl,w) -> + try + let _, body = dest_exists w in + if can (find_term (fun t -> + try fst(dest_const t) = "memaccess_inbounds" with _ -> false)) body + then + (DISCHARGE_MEMSAFE_ASM_TAC ORELSE + (try + let _, ainbds_th = List.find (fun (_,th) -> + let c = concl th in + try fst(dest_const(rator(rator(rator c)))) = "memaccess_inbounds" + with _ -> false) asl in + let e_var = rand(rator(rator(concl ainbds_th))) in + EXISTS_TAC e_var THEN ASM_REWRITE_TAC[] + with _ -> NO_TAC)) + else NO_TAC + with _ -> NO_TAC)); + (* offset-exit *) + ASM_CASES_TAC + `256 <= LENGTH(REJ_SAMPLE(SUB_LIST(0, 8 * N + K) (inlist:(24 word)list)))` + THENL + [SUBGOAL_THEN + `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list))) = 256` + ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o SPEC `K - 1`) THEN + ANTS_TAC THENL [UNDISCH_TAC `~(K = 0)` THEN ARITH_TAC; ALL_TAC] THEN + MP_TAC(ISPECL [`inlist:(24 word)list`; `8 * N + K - 1`] REJ_SAMPLE_STEP_LE) THEN + SUBGOAL_THEN `(8 * N + K - 1) + 1 = 8 * N + K` SUBST1_TAC THENL + [UNDISCH_TAC `~(K = 0)` THEN ARITH_TAC; ALL_TAC] THEN + UNDISCH_TAC + `256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list)))` THEN + ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `SUB_LIST (0,256) (REJ_SAMPLE (inlist:(24 word)list)) = + REJ_SAMPLE (SUB_LIST (0, 8 * N + K) inlist)` + ASSUME_TAC THENL + [MATCH_MP_TAC REJ_SAMPLE_PREFIX_256 THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + RULE_ASSUM_TAC(REWRITE_RULE[ASSUME + `LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list))) = 256`]) THEN + ASM_REWRITE_TAC[] THEN REPEAT CONJ_TAC THEN + TRY (FIRST_ASSUM ACCEPT_TAC) THEN + TRY (REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC) THEN + TRY (W(fun (asl,w) -> + try + let _, body = dest_exists w in + if can (find_term (fun t -> + try fst(dest_const t) = "memaccess_inbounds" with _ -> false)) body + then + (DISCHARGE_MEMSAFE_ASM_TAC ORELSE + (try + let _, ainbds_th = List.find (fun (_,th) -> + let c = concl th in + try fst(dest_const(rator(rator(rator c)))) = "memaccess_inbounds" + with _ -> false) asl in + let e_var = rand(rator(rator(concl ainbds_th))) in + EXISTS_TAC e_var THEN ASM_REWRITE_TAC[] + with _ -> NO_TAC)) + else NO_TAC + with _ -> NO_TAC)); + SUBGOAL_THEN `SUB_LIST (0, 8 * N + K) (inlist:(24 word)list) = inlist` + SUBST1_TAC THENL + [MATCH_MP_TAC SUB_LIST_REFL THEN + UNDISCH_TAC `LENGTH (inlist:(24 word)list) = 280` THEN + UNDISCH_TAC `837 < 24 * N + 3 * K` THEN ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `LENGTH (REJ_SAMPLE (inlist:(24 word)list)) <= 256` + ASSUME_TAC THENL + [UNDISCH_TAC + `~(256 <= LENGTH (REJ_SAMPLE (SUB_LIST (0,8 * N + K) (inlist:(24 word)list))))` THEN + SUBGOAL_THEN `SUB_LIST (0, 8 * N + K) (inlist:(24 word)list) = inlist` + SUBST1_TAC THENL + [MATCH_MP_TAC SUB_LIST_REFL THEN + UNDISCH_TAC `LENGTH (inlist:(24 word)list) = 280` THEN + UNDISCH_TAC `837 < 24 * N + 3 * K` THEN ARITH_TAC; + ALL_TAC] THEN + ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `SUB_LIST (0,256) (REJ_SAMPLE (inlist:(24 word)list)) = REJ_SAMPLE inlist` + SUBST1_TAC THENL + [MATCH_MP_TAC SUB_LIST_REFL THEN ASM_REWRITE_TAC[]; + ALL_TAC] THEN + REWRITE_TAC[] THEN + (fun (asl, w) -> + try + let has_const name t = try fst(dest_const t) = name with _ -> false in + let has_var name t = try fst(dest_var t) = name with _ -> false in + let mem_hyp = snd(List.find (fun (_, th) -> + let c = concl th in + is_eq c && + can (find_term (has_const "REJ_SAMPLE")) c && + can (find_term (has_const "bytes")) c && + can (find_term (has_const "memory")) c && + can (find_term (has_var "res")) c) asl) in + let len280 = snd(List.find (fun (_, th) -> + concl th = `LENGTH (inlist:(24 word)list) = 280`) asl) in + let off837 = snd(List.find (fun (_, th) -> + concl th = `837 < 24 * N + 3 * K`) asl) in + let bound_th = MP (MP + (ARITH_RULE `LENGTH (inlist:(24 word)list) = 280 + ==> 837 < 24 * N + 3 * K + ==> LENGTH inlist <= 8 * N + K`) len280) off837 in + let sub_eq = MATCH_MP + (ISPECL [`inlist:(24 word)list`; `8 * N + K`] SUB_LIST_REFL) + bound_th in + let mem_hyp' = REWRITE_RULE[sub_eq] mem_hyp in + Printf.printf "DBG: K>0 oe-only mem_hyp' = %s\n%!" (string_of_term (concl mem_hyp')); + (REPEAT CONJ_TAC THEN + TRY (FIRST_ASSUM ACCEPT_TAC) THEN + TRY (ACCEPT_TAC mem_hyp') THEN + TRY (REWRITE_TAC[SOME_FLAGS] THEN MONOTONE_MAYCHANGE_TAC) THEN + TRY (W(fun (asl,w) -> + try + let _, body = dest_exists w in + if can (find_term (fun t -> + try fst(dest_const t) = "memaccess_inbounds" with _ -> false)) body + then + (DISCHARGE_MEMSAFE_ASM_TAC ORELSE + (try + let _, ainbds_th = List.find (fun (_,th) -> + let c = concl th in + try fst(dest_const(rator(rator(rator c)))) = "memaccess_inbounds" + with _ -> false) asl in + let e_var = rand(rator(rator(concl ainbds_th))) in + EXISTS_TAC e_var THEN ASM_REWRITE_TAC[] + with _ -> NO_TAC)) + else NO_TAC + with _ -> NO_TAC))) (asl, w) + with e -> Printf.printf "DBG: memory finalize failed: %s\n%!" (Printexc.to_string e); failwith "memory finalize failed")]]]]]);; + + +(* ------------------------------------------------------------------------- *) +(* Subroutine variants for memory safety. *) +(* ------------------------------------------------------------------------- *) + +let MLDSA_REJ_UNIFORM_NOIBT_SUBROUTINE_SAFE = time prove + (`!res buf table (inlist:(24 word)list) e pc stackpointer returnaddress. + LENGTH inlist = 280 /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_tmc) (res, 1024) /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_tmc) (buf, 840) /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_tmc) (table, 2048) /\ + nonoverlapping (res, 1024) (buf, 840) /\ + nonoverlapping (res, 1024) (table, 2048) /\ + nonoverlapping (stackpointer, 8) (res, 1024) /\ + nonoverlapping (stackpointer, 8) (buf, 840) /\ + nonoverlapping (stackpointer, 8) (table, 2048) /\ + nonoverlapping (stackpointer, 8) (word pc, LENGTH mldsa_rej_uniform_tmc) + ==> ensures x86 + (\s. bytes_loaded s (word pc) mldsa_rej_uniform_tmc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + C_ARGUMENTS [res; buf; table] s /\ + read(memory :> bytes(buf,840)) s = num_of_wordlist inlist /\ + read(memory :> bytes(table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list) /\ + read events s = e) + (\s. read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (exists e2. + read events s = APPEND e2 e /\ + memaccess_inbounds e2 + [buf,840; table,2048; stackpointer,8] + [res,1024])) + (MAYCHANGE [RSP] ,, MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(res,1024)])`, + X86_PROMOTE_RETURN_NOSTACK_TAC mldsa_rej_uniform_tmc + MLDSA_REJ_UNIFORM_MEMSAFE THEN + DISCHARGE_MEMSAFE_TAC);; + +let MLDSA_REJ_UNIFORM_SUBROUTINE_SAFE = time prove + (`!res buf table (inlist:(24 word)list) e pc stackpointer returnaddress. + LENGTH inlist = 280 /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_mc) (res, 1024) /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_mc) (buf, 840) /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_mc) (table, 2048) /\ + nonoverlapping (res, 1024) (buf, 840) /\ + nonoverlapping (res, 1024) (table, 2048) /\ + nonoverlapping (stackpointer, 8) (res, 1024) /\ + nonoverlapping (stackpointer, 8) (buf, 840) /\ + nonoverlapping (stackpointer, 8) (table, 2048) /\ + nonoverlapping (stackpointer, 8) (word pc, LENGTH mldsa_rej_uniform_mc) + ==> ensures x86 + (\s. bytes_loaded s (word pc) mldsa_rej_uniform_mc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + C_ARGUMENTS [res; buf; table] s /\ + read(memory :> bytes(buf,840)) s = num_of_wordlist inlist /\ + read(memory :> bytes(table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list) /\ + read events s = e) + (\s. read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (exists e2. + read events s = APPEND e2 e /\ + memaccess_inbounds e2 + [buf,840; table,2048; stackpointer,8] + [res,1024])) + (MAYCHANGE [RSP] ,, MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(res,1024)])`, + MATCH_ACCEPT_TAC(ADD_IBT_RULE MLDSA_REJ_UNIFORM_NOIBT_SUBROUTINE_SAFE));; + +(* ========================================================================= *) +(* Windows ABI variants. *) +(* ========================================================================= *) + +let mldsa_rej_uniform_windows_mc = define_from_elf "mldsa_rej_uniform_windows_mc" + "x86/mldsa/mldsa_rej_uniform.obj";; + +let mldsa_rej_uniform_windows_tmc = + define_trimmed "mldsa_rej_uniform_windows_tmc" mldsa_rej_uniform_windows_mc;; + +let MLDSA_REJ_UNIFORM_WINDOWS_TMC_EXEC = + X86_MK_EXEC_RULE mldsa_rej_uniform_windows_tmc;; + +let MLDSA_REJ_UNIFORM_NOIBT_WINDOWS_SUBROUTINE_CORRECT = prove + (`!res buf table (inlist:(24 word)list) pc stackpointer returnaddress. + LENGTH inlist = 280 /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_windows_tmc) (res, 1024) /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_windows_tmc) (buf, 840) /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_windows_tmc) (table, 2048) /\ + nonoverlapping (res, 1024) (buf, 840) /\ + nonoverlapping (res, 1024) (table, 2048) /\ + nonoverlapping (word_sub stackpointer (word 176), 184) (res, 1024) /\ + nonoverlapping (word_sub stackpointer (word 176), 184) (buf, 840) /\ + nonoverlapping (word_sub stackpointer (word 176), 184) (table, 2048) /\ + nonoverlapping (word_sub stackpointer (word 176), 184) + (word pc, LENGTH mldsa_rej_uniform_windows_tmc) + ==> ensures x86 + (\s. bytes_loaded s (word pc) mldsa_rej_uniform_windows_tmc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + WINDOWS_C_ARGUMENTS [res; buf; table] s /\ + read(memory :> bytes(buf,840)) s = num_of_wordlist inlist /\ + read(memory :> bytes(table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list)) + (\s. read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (let outlist = SUB_LIST(0,256) (REJ_SAMPLE inlist) in + let outlen = LENGTH outlist in + WINDOWS_C_RETURN s = word outlen /\ + read(memory :> bytes(res,4 * outlen)) s = + num_of_wordlist outlist)) + (MAYCHANGE [RSP] ,, WINDOWS_MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(res,1024)] ,, + MAYCHANGE [memory :> bytes(word_sub stackpointer (word 176),176)])`, + REPLICATE_TAC 5 GEN_TAC THEN + WORD_FORALL_OFFSET_TAC 176 THEN REPEAT GEN_TAC THEN + + REWRITE_TAC[fst MLDSA_REJ_UNIFORM_WINDOWS_TMC_EXEC] THEN + REPEAT STRIP_TAC THEN + REWRITE_TAC[WINDOWS_C_ARGUMENTS; WINDOWS_C_RETURN] THEN + REWRITE_TAC[WINDOWS_MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI] THEN + + ENSURES_PRESERVED_TAC "rdi_init" `RDI` THEN + ENSURES_PRESERVED_TAC "rsi_init" `RSI` THEN + ENSURES_PRESERVED_TAC "init_xmm6" `ZMM6 :> bottomhalf :> bottomhalf` THEN + ENSURES_PRESERVED_TAC "init_xmm7" `ZMM7 :> bottomhalf :> bottomhalf` THEN + ENSURES_PRESERVED_TAC "init_xmm8" `ZMM8 :> bottomhalf :> bottomhalf` THEN + ENSURES_PRESERVED_TAC "init_xmm9" `ZMM9 :> bottomhalf :> bottomhalf` THEN + ENSURES_PRESERVED_TAC "init_xmm10" `ZMM10 :> bottomhalf :> bottomhalf` THEN + ENSURES_PRESERVED_TAC "init_xmm11" `ZMM11 :> bottomhalf :> bottomhalf` THEN + ENSURES_PRESERVED_TAC "init_xmm12" `ZMM12 :> bottomhalf :> bottomhalf` THEN + ENSURES_PRESERVED_TAC "init_xmm13" `ZMM13 :> bottomhalf :> bottomhalf` THEN + ENSURES_PRESERVED_TAC "init_xmm14" `ZMM14 :> bottomhalf :> bottomhalf` THEN + ENSURES_PRESERVED_TAC "init_xmm15" `ZMM15 :> bottomhalf :> bottomhalf` THEN + + REWRITE_TAC[READ_ZMM_BOTTOM_QUARTER'] THEN + REWRITE_TAC(map GSYM + [YMM6;YMM7;YMM8;YMM9;YMM10;YMM11;YMM12;YMM13;YMM14;YMM15]) THEN + + GHOST_INTRO_TAC `init_ymm6:int256` `read YMM6` THEN + GHOST_INTRO_TAC `init_ymm7:int256` `read YMM7` THEN + GHOST_INTRO_TAC `init_ymm8:int256` `read YMM8` THEN + GHOST_INTRO_TAC `init_ymm9:int256` `read YMM9` THEN + GHOST_INTRO_TAC `init_ymm10:int256` `read YMM10` THEN + GHOST_INTRO_TAC `init_ymm11:int256` `read YMM11` THEN + GHOST_INTRO_TAC `init_ymm12:int256` `read YMM12` THEN + GHOST_INTRO_TAC `init_ymm13:int256` `read YMM13` THEN + GHOST_INTRO_TAC `init_ymm14:int256` `read YMM14` THEN + GHOST_INTRO_TAC `init_ymm15:int256` `read YMM15` THEN + + GLOBALIZE_PRECONDITION_TAC THEN + (* Substitute init_xmmN → word_zx init_ymmN in the goal, using the + assumptions from ENSURES_PRESERVED_TAC. Unlike mldsa_reduce's broad + `REPEAT(FIRST_X_ASSUM(SUBST1_TAC o SYM))`, this keeps `LENGTH inlist = 280` + as an assumption (we need it later for the linux CORRECT precondition). *) + MAP_EVERY (fun n -> + UNDISCH_THEN + (mk_eq(mk_comb(`word_zx:int256->int128`, + mk_var("init_ymm"^string_of_int n,`:int256`)), + mk_var("init_xmm"^string_of_int n,`:int128`))) + (SUBST1_TAC o SYM)) + [6;7;8;9;10;11;12;13;14;15] THEN + + ENSURES_INIT_TAC "s0" THEN + X86_STEPS_TAC MLDSA_REJ_UNIFORM_WINDOWS_TMC_EXEC (1--16) THEN + + MP_TAC(SPECL + [`res:int64`; `buf:int64`; `table:int64`; + `inlist:(24 word)list`; `pc + 91`] + MLDSA_REJ_UNIFORM_CORRECT) THEN + ASM_REWRITE_TAC[C_ARGUMENTS; SOME_FLAGS] THEN + ANTS_TAC THENL [REPEAT CONJ_TAC THEN NONOVERLAPPING_TAC; ALL_TAC] THEN + + X86_BIGSTEP_TAC MLDSA_REJ_UNIFORM_WINDOWS_TMC_EXEC "s17" THENL + [FIRST_ASSUM(MATCH_ACCEPT_TAC o MATCH_MP + (BYTES_LOADED_SUBPROGRAM_RULE mldsa_rej_uniform_windows_tmc + (REWRITE_RULE[BUTLAST_CLAUSES] + (AP_TERM `BUTLAST:byte list->byte list` mldsa_rej_uniform_tmc)) + 91)); + RULE_ASSUM_TAC(CONV_RULE(TRY_CONV RIP_PLUS_CONV))] THEN + + RULE_ASSUM_TAC(CONV_RULE(DEPTH_CONV let_CONV)) THEN + RULE_ASSUM_TAC(REWRITE_RULE[C_RETURN]) THEN + ABBREV_TAC + `outlen = LENGTH (SUB_LIST (0,256) (REJ_SAMPLE (inlist:(24 word)list)))` THEN + FIRST_X_ASSUM(CONJUNCTS_THEN ASSUME_TAC) THEN + + MAP_EVERY ABBREV_TAC + [`ymm6_epilog = read YMM6 s17`; + `ymm7_epilog = read YMM7 s17`; + `ymm8_epilog = read YMM8 s17`; + `ymm9_epilog = read YMM9 s17`; + `ymm10_epilog = read YMM10 s17`; + `ymm11_epilog = read YMM11 s17`; + `ymm12_epilog = read YMM12 s17`; + `ymm13_epilog = read YMM13 s17`; + `ymm14_epilog = read YMM14 s17`; + `ymm15_epilog = read YMM15 s17`] THEN + + X86_STEPS_TAC MLDSA_REJ_UNIFORM_WINDOWS_TMC_EXEC (18--31) THEN + + RULE_ASSUM_TAC(REWRITE_RULE[MAYCHANGE_ZMM_QUARTER]) THEN + RULE_ASSUM_TAC(REWRITE_RULE[MAYCHANGE_YMM_SSE_QUARTER]) THEN + + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + REPEAT CONJ_TAC THEN CONV_TAC WORD_BLAST);; + +let MLDSA_REJ_UNIFORM_WINDOWS_SUBROUTINE_CORRECT = prove + (`!res buf table (inlist:(24 word)list) pc stackpointer returnaddress. + LENGTH inlist = 280 /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_windows_mc) (res, 1024) /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_windows_mc) (buf, 840) /\ + nonoverlapping (word pc, LENGTH mldsa_rej_uniform_windows_mc) (table, 2048) /\ + nonoverlapping (res, 1024) (buf, 840) /\ + nonoverlapping (res, 1024) (table, 2048) /\ + nonoverlapping (word_sub stackpointer (word 176), 184) (res, 1024) /\ + nonoverlapping (word_sub stackpointer (word 176), 184) (buf, 840) /\ + nonoverlapping (word_sub stackpointer (word 176), 184) (table, 2048) /\ + nonoverlapping (word_sub stackpointer (word 176), 184) + (word pc, LENGTH mldsa_rej_uniform_windows_mc) + ==> ensures x86 + (\s. bytes_loaded s (word pc) mldsa_rej_uniform_windows_mc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + WINDOWS_C_ARGUMENTS [res; buf; table] s /\ + read(memory :> bytes(buf,840)) s = num_of_wordlist inlist /\ + read(memory :> bytes(table,2048)) s = + num_of_wordlist(mldsa_rej_uniform_table:byte list)) + (\s. read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (let outlist = SUB_LIST(0,256) (REJ_SAMPLE inlist) in + let outlen = LENGTH outlist in + WINDOWS_C_RETURN s = word outlen /\ + read(memory :> bytes(res,4 * outlen)) s = + num_of_wordlist outlist)) + (MAYCHANGE [RSP] ,, WINDOWS_MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(res,1024)] ,, + MAYCHANGE [memory :> bytes(word_sub stackpointer (word 176),176)])`, + MATCH_ACCEPT_TAC(ADD_IBT_RULE MLDSA_REJ_UNIFORM_NOIBT_WINDOWS_SUBROUTINE_CORRECT));; + diff --git a/x86/proofs/mldsa_rej_uniform_table.ml b/x86/proofs/mldsa_rej_uniform_table.ml new file mode 100644 index 000000000..818a324f7 --- /dev/null +++ b/x86/proofs/mldsa_rej_uniform_table.ml @@ -0,0 +1,268 @@ +(* + * Copyright (c) The mldsa-native project authors + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 + *) + +(* Lookup table for ML-DSA rejection uniform sampling. *) +(* Each entry is 8 bytes: permutation indices for VPERMD. *) + +let mldsa_rej_uniform_table = (REWRITE_RULE[MAP] o define) + `mldsa_rej_uniform_table:byte list = MAP word [ + 0; 0; 0; 0; 0; 0; 0; 0; + 0; 0; 0; 0; 0; 0; 0; 0; + 1; 0; 0; 0; 0; 0; 0; 0; + 0; 1; 0; 0; 0; 0; 0; 0; + 2; 0; 0; 0; 0; 0; 0; 0; + 0; 2; 0; 0; 0; 0; 0; 0; + 1; 2; 0; 0; 0; 0; 0; 0; + 0; 1; 2; 0; 0; 0; 0; 0; + 3; 0; 0; 0; 0; 0; 0; 0; + 0; 3; 0; 0; 0; 0; 0; 0; + 1; 3; 0; 0; 0; 0; 0; 0; + 0; 1; 3; 0; 0; 0; 0; 0; + 2; 3; 0; 0; 0; 0; 0; 0; + 0; 2; 3; 0; 0; 0; 0; 0; + 1; 2; 3; 0; 0; 0; 0; 0; + 0; 1; 2; 3; 0; 0; 0; 0; + 4; 0; 0; 0; 0; 0; 0; 0; + 0; 4; 0; 0; 0; 0; 0; 0; + 1; 4; 0; 0; 0; 0; 0; 0; + 0; 1; 4; 0; 0; 0; 0; 0; + 2; 4; 0; 0; 0; 0; 0; 0; + 0; 2; 4; 0; 0; 0; 0; 0; + 1; 2; 4; 0; 0; 0; 0; 0; + 0; 1; 2; 4; 0; 0; 0; 0; + 3; 4; 0; 0; 0; 0; 0; 0; + 0; 3; 4; 0; 0; 0; 0; 0; + 1; 3; 4; 0; 0; 0; 0; 0; + 0; 1; 3; 4; 0; 0; 0; 0; + 2; 3; 4; 0; 0; 0; 0; 0; + 0; 2; 3; 4; 0; 0; 0; 0; + 1; 2; 3; 4; 0; 0; 0; 0; + 0; 1; 2; 3; 4; 0; 0; 0; + 5; 0; 0; 0; 0; 0; 0; 0; + 0; 5; 0; 0; 0; 0; 0; 0; + 1; 5; 0; 0; 0; 0; 0; 0; + 0; 1; 5; 0; 0; 0; 0; 0; + 2; 5; 0; 0; 0; 0; 0; 0; + 0; 2; 5; 0; 0; 0; 0; 0; + 1; 2; 5; 0; 0; 0; 0; 0; + 0; 1; 2; 5; 0; 0; 0; 0; + 3; 5; 0; 0; 0; 0; 0; 0; + 0; 3; 5; 0; 0; 0; 0; 0; + 1; 3; 5; 0; 0; 0; 0; 0; + 0; 1; 3; 5; 0; 0; 0; 0; + 2; 3; 5; 0; 0; 0; 0; 0; + 0; 2; 3; 5; 0; 0; 0; 0; + 1; 2; 3; 5; 0; 0; 0; 0; + 0; 1; 2; 3; 5; 0; 0; 0; + 4; 5; 0; 0; 0; 0; 0; 0; + 0; 4; 5; 0; 0; 0; 0; 0; + 1; 4; 5; 0; 0; 0; 0; 0; + 0; 1; 4; 5; 0; 0; 0; 0; + 2; 4; 5; 0; 0; 0; 0; 0; + 0; 2; 4; 5; 0; 0; 0; 0; + 1; 2; 4; 5; 0; 0; 0; 0; + 0; 1; 2; 4; 5; 0; 0; 0; + 3; 4; 5; 0; 0; 0; 0; 0; + 0; 3; 4; 5; 0; 0; 0; 0; + 1; 3; 4; 5; 0; 0; 0; 0; + 0; 1; 3; 4; 5; 0; 0; 0; + 2; 3; 4; 5; 0; 0; 0; 0; + 0; 2; 3; 4; 5; 0; 0; 0; + 1; 2; 3; 4; 5; 0; 0; 0; + 0; 1; 2; 3; 4; 5; 0; 0; + 6; 0; 0; 0; 0; 0; 0; 0; + 0; 6; 0; 0; 0; 0; 0; 0; + 1; 6; 0; 0; 0; 0; 0; 0; + 0; 1; 6; 0; 0; 0; 0; 0; + 2; 6; 0; 0; 0; 0; 0; 0; + 0; 2; 6; 0; 0; 0; 0; 0; + 1; 2; 6; 0; 0; 0; 0; 0; + 0; 1; 2; 6; 0; 0; 0; 0; + 3; 6; 0; 0; 0; 0; 0; 0; + 0; 3; 6; 0; 0; 0; 0; 0; + 1; 3; 6; 0; 0; 0; 0; 0; + 0; 1; 3; 6; 0; 0; 0; 0; + 2; 3; 6; 0; 0; 0; 0; 0; + 0; 2; 3; 6; 0; 0; 0; 0; + 1; 2; 3; 6; 0; 0; 0; 0; + 0; 1; 2; 3; 6; 0; 0; 0; + 4; 6; 0; 0; 0; 0; 0; 0; + 0; 4; 6; 0; 0; 0; 0; 0; + 1; 4; 6; 0; 0; 0; 0; 0; + 0; 1; 4; 6; 0; 0; 0; 0; + 2; 4; 6; 0; 0; 0; 0; 0; + 0; 2; 4; 6; 0; 0; 0; 0; + 1; 2; 4; 6; 0; 0; 0; 0; + 0; 1; 2; 4; 6; 0; 0; 0; + 3; 4; 6; 0; 0; 0; 0; 0; + 0; 3; 4; 6; 0; 0; 0; 0; + 1; 3; 4; 6; 0; 0; 0; 0; + 0; 1; 3; 4; 6; 0; 0; 0; + 2; 3; 4; 6; 0; 0; 0; 0; + 0; 2; 3; 4; 6; 0; 0; 0; + 1; 2; 3; 4; 6; 0; 0; 0; + 0; 1; 2; 3; 4; 6; 0; 0; + 5; 6; 0; 0; 0; 0; 0; 0; + 0; 5; 6; 0; 0; 0; 0; 0; + 1; 5; 6; 0; 0; 0; 0; 0; + 0; 1; 5; 6; 0; 0; 0; 0; + 2; 5; 6; 0; 0; 0; 0; 0; + 0; 2; 5; 6; 0; 0; 0; 0; + 1; 2; 5; 6; 0; 0; 0; 0; + 0; 1; 2; 5; 6; 0; 0; 0; + 3; 5; 6; 0; 0; 0; 0; 0; + 0; 3; 5; 6; 0; 0; 0; 0; + 1; 3; 5; 6; 0; 0; 0; 0; + 0; 1; 3; 5; 6; 0; 0; 0; + 2; 3; 5; 6; 0; 0; 0; 0; + 0; 2; 3; 5; 6; 0; 0; 0; + 1; 2; 3; 5; 6; 0; 0; 0; + 0; 1; 2; 3; 5; 6; 0; 0; + 4; 5; 6; 0; 0; 0; 0; 0; + 0; 4; 5; 6; 0; 0; 0; 0; + 1; 4; 5; 6; 0; 0; 0; 0; + 0; 1; 4; 5; 6; 0; 0; 0; + 2; 4; 5; 6; 0; 0; 0; 0; + 0; 2; 4; 5; 6; 0; 0; 0; + 1; 2; 4; 5; 6; 0; 0; 0; + 0; 1; 2; 4; 5; 6; 0; 0; + 3; 4; 5; 6; 0; 0; 0; 0; + 0; 3; 4; 5; 6; 0; 0; 0; + 1; 3; 4; 5; 6; 0; 0; 0; + 0; 1; 3; 4; 5; 6; 0; 0; + 2; 3; 4; 5; 6; 0; 0; 0; + 0; 2; 3; 4; 5; 6; 0; 0; + 1; 2; 3; 4; 5; 6; 0; 0; + 0; 1; 2; 3; 4; 5; 6; 0; + 7; 0; 0; 0; 0; 0; 0; 0; + 0; 7; 0; 0; 0; 0; 0; 0; + 1; 7; 0; 0; 0; 0; 0; 0; + 0; 1; 7; 0; 0; 0; 0; 0; + 2; 7; 0; 0; 0; 0; 0; 0; + 0; 2; 7; 0; 0; 0; 0; 0; + 1; 2; 7; 0; 0; 0; 0; 0; + 0; 1; 2; 7; 0; 0; 0; 0; + 3; 7; 0; 0; 0; 0; 0; 0; + 0; 3; 7; 0; 0; 0; 0; 0; + 1; 3; 7; 0; 0; 0; 0; 0; + 0; 1; 3; 7; 0; 0; 0; 0; + 2; 3; 7; 0; 0; 0; 0; 0; + 0; 2; 3; 7; 0; 0; 0; 0; + 1; 2; 3; 7; 0; 0; 0; 0; + 0; 1; 2; 3; 7; 0; 0; 0; + 4; 7; 0; 0; 0; 0; 0; 0; + 0; 4; 7; 0; 0; 0; 0; 0; + 1; 4; 7; 0; 0; 0; 0; 0; + 0; 1; 4; 7; 0; 0; 0; 0; + 2; 4; 7; 0; 0; 0; 0; 0; + 0; 2; 4; 7; 0; 0; 0; 0; + 1; 2; 4; 7; 0; 0; 0; 0; + 0; 1; 2; 4; 7; 0; 0; 0; + 3; 4; 7; 0; 0; 0; 0; 0; + 0; 3; 4; 7; 0; 0; 0; 0; + 1; 3; 4; 7; 0; 0; 0; 0; + 0; 1; 3; 4; 7; 0; 0; 0; + 2; 3; 4; 7; 0; 0; 0; 0; + 0; 2; 3; 4; 7; 0; 0; 0; + 1; 2; 3; 4; 7; 0; 0; 0; + 0; 1; 2; 3; 4; 7; 0; 0; + 5; 7; 0; 0; 0; 0; 0; 0; + 0; 5; 7; 0; 0; 0; 0; 0; + 1; 5; 7; 0; 0; 0; 0; 0; + 0; 1; 5; 7; 0; 0; 0; 0; + 2; 5; 7; 0; 0; 0; 0; 0; + 0; 2; 5; 7; 0; 0; 0; 0; + 1; 2; 5; 7; 0; 0; 0; 0; + 0; 1; 2; 5; 7; 0; 0; 0; + 3; 5; 7; 0; 0; 0; 0; 0; + 0; 3; 5; 7; 0; 0; 0; 0; + 1; 3; 5; 7; 0; 0; 0; 0; + 0; 1; 3; 5; 7; 0; 0; 0; + 2; 3; 5; 7; 0; 0; 0; 0; + 0; 2; 3; 5; 7; 0; 0; 0; + 1; 2; 3; 5; 7; 0; 0; 0; + 0; 1; 2; 3; 5; 7; 0; 0; + 4; 5; 7; 0; 0; 0; 0; 0; + 0; 4; 5; 7; 0; 0; 0; 0; + 1; 4; 5; 7; 0; 0; 0; 0; + 0; 1; 4; 5; 7; 0; 0; 0; + 2; 4; 5; 7; 0; 0; 0; 0; + 0; 2; 4; 5; 7; 0; 0; 0; + 1; 2; 4; 5; 7; 0; 0; 0; + 0; 1; 2; 4; 5; 7; 0; 0; + 3; 4; 5; 7; 0; 0; 0; 0; + 0; 3; 4; 5; 7; 0; 0; 0; + 1; 3; 4; 5; 7; 0; 0; 0; + 0; 1; 3; 4; 5; 7; 0; 0; + 2; 3; 4; 5; 7; 0; 0; 0; + 0; 2; 3; 4; 5; 7; 0; 0; + 1; 2; 3; 4; 5; 7; 0; 0; + 0; 1; 2; 3; 4; 5; 7; 0; + 6; 7; 0; 0; 0; 0; 0; 0; + 0; 6; 7; 0; 0; 0; 0; 0; + 1; 6; 7; 0; 0; 0; 0; 0; + 0; 1; 6; 7; 0; 0; 0; 0; + 2; 6; 7; 0; 0; 0; 0; 0; + 0; 2; 6; 7; 0; 0; 0; 0; + 1; 2; 6; 7; 0; 0; 0; 0; + 0; 1; 2; 6; 7; 0; 0; 0; + 3; 6; 7; 0; 0; 0; 0; 0; + 0; 3; 6; 7; 0; 0; 0; 0; + 1; 3; 6; 7; 0; 0; 0; 0; + 0; 1; 3; 6; 7; 0; 0; 0; + 2; 3; 6; 7; 0; 0; 0; 0; + 0; 2; 3; 6; 7; 0; 0; 0; + 1; 2; 3; 6; 7; 0; 0; 0; + 0; 1; 2; 3; 6; 7; 0; 0; + 4; 6; 7; 0; 0; 0; 0; 0; + 0; 4; 6; 7; 0; 0; 0; 0; + 1; 4; 6; 7; 0; 0; 0; 0; + 0; 1; 4; 6; 7; 0; 0; 0; + 2; 4; 6; 7; 0; 0; 0; 0; + 0; 2; 4; 6; 7; 0; 0; 0; + 1; 2; 4; 6; 7; 0; 0; 0; + 0; 1; 2; 4; 6; 7; 0; 0; + 3; 4; 6; 7; 0; 0; 0; 0; + 0; 3; 4; 6; 7; 0; 0; 0; + 1; 3; 4; 6; 7; 0; 0; 0; + 0; 1; 3; 4; 6; 7; 0; 0; + 2; 3; 4; 6; 7; 0; 0; 0; + 0; 2; 3; 4; 6; 7; 0; 0; + 1; 2; 3; 4; 6; 7; 0; 0; + 0; 1; 2; 3; 4; 6; 7; 0; + 5; 6; 7; 0; 0; 0; 0; 0; + 0; 5; 6; 7; 0; 0; 0; 0; + 1; 5; 6; 7; 0; 0; 0; 0; + 0; 1; 5; 6; 7; 0; 0; 0; + 2; 5; 6; 7; 0; 0; 0; 0; + 0; 2; 5; 6; 7; 0; 0; 0; + 1; 2; 5; 6; 7; 0; 0; 0; + 0; 1; 2; 5; 6; 7; 0; 0; + 3; 5; 6; 7; 0; 0; 0; 0; + 0; 3; 5; 6; 7; 0; 0; 0; + 1; 3; 5; 6; 7; 0; 0; 0; + 0; 1; 3; 5; 6; 7; 0; 0; + 2; 3; 5; 6; 7; 0; 0; 0; + 0; 2; 3; 5; 6; 7; 0; 0; + 1; 2; 3; 5; 6; 7; 0; 0; + 0; 1; 2; 3; 5; 6; 7; 0; + 4; 5; 6; 7; 0; 0; 0; 0; + 0; 4; 5; 6; 7; 0; 0; 0; + 1; 4; 5; 6; 7; 0; 0; 0; + 0; 1; 4; 5; 6; 7; 0; 0; + 2; 4; 5; 6; 7; 0; 0; 0; + 0; 2; 4; 5; 6; 7; 0; 0; + 1; 2; 4; 5; 6; 7; 0; 0; + 0; 1; 2; 4; 5; 6; 7; 0; + 3; 4; 5; 6; 7; 0; 0; 0; + 0; 3; 4; 5; 6; 7; 0; 0; + 1; 3; 4; 5; 6; 7; 0; 0; + 0; 1; 3; 4; 5; 6; 7; 0; + 2; 3; 4; 5; 6; 7; 0; 0; + 0; 2; 3; 4; 5; 6; 7; 0; + 1; 2; 3; 4; 5; 6; 7; 0; + 0; 1; 2; 3; 4; 5; 6; 7]` +;; diff --git a/x86/proofs/simulator.ml b/x86/proofs/simulator.ml index cb6c51d14..ac1f03f53 100755 --- a/x86/proofs/simulator.ml +++ b/x86/proofs/simulator.ml @@ -885,6 +885,27 @@ let iclasses = iclasses_regreg @ [0xc5; 0xf5; 0x73; 0xda; 0x63]; (* VPSRLDQ (%_% ymm1) (%_% ymm2) (Imm8 (word 99)) *) [0xc4; 0xc1; 0x79; 0xc5; 0xcc; 0x64]; (* VPEXTRW (% ecx) (%_% xmm12) (Imm8 (word 100)) *) [0xc5; 0x79; 0xc5; 0xfe; 0x52]; (* VPEXTRW (% r15d) (%_% xmm6) (Imm8 (word 82)) *) + [0xc5; 0xfc; 0x50; 0xc1]; (* vmovmskps eax, ymm1 *) + [0xc5; 0xfc; 0x50; 0xca]; (* vmovmskps ecx, ymm2 *) + [0xc5; 0xfc; 0x50; 0xd3]; (* vmovmskps edx, ymm3 *) + [0xc5; 0xfc; 0x50; 0xfc]; (* vmovmskps edi, ymm4 *) + [0xc4; 0x41; 0x7c; 0x50; 0xc1]; (* vmovmskps r8d, ymm9 *) + [0xc4; 0x41; 0x7c; 0x50; 0xd3]; (* vmovmskps r10d, ymm11 *) + [0xc4; 0x41; 0x7c; 0x50; 0xe5]; (* vmovmskps r12d, ymm13 *) + [0xc4; 0x41; 0x7c; 0x50; 0xf7]; (* vmovmskps r14d, ymm15 *) + [0xc4; 0xe2; 0x7d; 0x31; 0xc1]; (* vpmovzxbd ymm0, xmm1 *) + [0xc4; 0xe2; 0x7d; 0x31; 0xca]; (* vpmovzxbd ymm1, xmm2 *) + [0xc4; 0xe2; 0x7d; 0x31; 0xd3]; (* vpmovzxbd ymm2, xmm3 *) + [0xc4; 0xe2; 0x7d; 0x31; 0xdc]; (* vpmovzxbd ymm3, xmm4 *) + [0xc4; 0x42; 0x7d; 0x31; 0xc1]; (* vpmovzxbd ymm8, xmm9 *) + [0xc4; 0x42; 0x7d; 0x31; 0xd3]; (* vpmovzxbd ymm10, xmm11 *) + [0xc4; 0x42; 0x7d; 0x31; 0xe5]; (* vpmovzxbd ymm12, xmm13 *) + [0xc4; 0x42; 0x7d; 0x31; 0xf7]; (* vpmovzxbd ymm14, xmm15 *) + [0xc4; 0xe2; 0x79; 0x31; 0xc1]; (* vpmovzxbd xmm0, xmm1 *) + [0xc4; 0xe2; 0x79; 0x31; 0xca]; (* vpmovzxbd xmm1, xmm2 *) + [0xc4; 0x42; 0x79; 0x31; 0xc1]; (* vpmovzxbd xmm8, xmm9 *) + [0xc4; 0x42; 0x79; 0x31; 0xf7]; (* vpmovzxbd xmm14, xmm15 *) + [0xc5; 0xf8; 0x77]; (* vzeroupper *) [0xc4; 0xe3; 0x69; 0x44; 0xcb; 0x00]; (* VPCLMULQDQ (%_% xmm1) (%_% xmm2) (%_% xmm3) (Imm8 (word 0)) *) [0xc4; 0xe3; 0x69; 0x44; 0xcb; 0x01]; (* VPCLMULQDQ (%_% xmm1) (%_% xmm2) (%_% xmm3) (Imm8 (word 1)) *) [0xc4; 0xe3; 0x69; 0x44; 0xcb; 0x10]; (* VPCLMULQDQ (%_% xmm1) (%_% xmm2) (%_% xmm3) (Imm8 (word 16)) *) diff --git a/x86/proofs/specifications.txt b/x86/proofs/specifications.txt index 973ed9518..40b85cbe0 100644 --- a/x86/proofs/specifications.txt +++ b/x86/proofs/specifications.txt @@ -1390,6 +1390,10 @@ MLDSA_REDUCE_NOIBT_SUBROUTINE_CORRECT MLDSA_REDUCE_NOIBT_WINDOWS_SUBROUTINE_CORRECT MLDSA_REDUCE_SUBROUTINE_CORRECT MLDSA_REDUCE_WINDOWS_SUBROUTINE_CORRECT +MLDSA_REJ_UNIFORM_NOIBT_SUBROUTINE_CORRECT +MLDSA_REJ_UNIFORM_NOIBT_WINDOWS_SUBROUTINE_CORRECT +MLDSA_REJ_UNIFORM_SUBROUTINE_CORRECT +MLDSA_REJ_UNIFORM_WINDOWS_SUBROUTINE_CORRECT MLKEM_BASEMUL_K2_NOIBT_SUBROUTINE_CORRECT MLKEM_BASEMUL_K2_NOIBT_WINDOWS_SUBROUTINE_CORRECT MLKEM_BASEMUL_K2_SUBROUTINE_CORRECT diff --git a/x86/proofs/subroutine_signatures.ml b/x86/proofs/subroutine_signatures.ml index 8ce7f72b6..35f26d22c 100644 --- a/x86/proofs/subroutine_signatures.ml +++ b/x86/proofs/subroutine_signatures.ml @@ -4862,6 +4862,24 @@ let subroutine_signatures = [ ]) ); +("mldsa_rej_uniform", + ([(*args*) + ("r", "int32_t[static 256]", (*is const?*)"false"); + ("buf", "uint8_t[static 840]", (*is const?*)"true"); + ("table", "uint64_t[static 256]", (*is const?*)"true"); + ], + "uint32_t", + [(* input buffers *) + ("buf", "840"(* num elems *), 1(* elem bytesize *)); + ("table", "256"(* num elems *), 8(* elem bytesize *)); + ], + [(* output buffers *) + ("r", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* temporary buffers *) + ]) +); + ("mlkem_basemul_k2", ([(*args*) ("r", "int16_t[static 256]", (*is const?*)"false"); diff --git a/x86/proofs/x86.ml b/x86/proofs/x86.ml index d281d11e0..927076363 100644 --- a/x86/proofs/x86.ml +++ b/x86/proofs/x86.ml @@ -1019,6 +1019,25 @@ let x86_DEC = new_definition let x86_ENDBR64 = new_definition `x86_ENDBR64 (s:x86state) = \s'. s = s'`;; +let x86_VZEROUPPER = new_definition + `x86_VZEROUPPER (s:x86state) = + (YMM0 := word_zx(word_subword (read YMM0 s) (0,128):int128) ,, + YMM1 := word_zx(word_subword (read YMM1 s) (0,128):int128) ,, + YMM2 := word_zx(word_subword (read YMM2 s) (0,128):int128) ,, + YMM3 := word_zx(word_subword (read YMM3 s) (0,128):int128) ,, + YMM4 := word_zx(word_subword (read YMM4 s) (0,128):int128) ,, + YMM5 := word_zx(word_subword (read YMM5 s) (0,128):int128) ,, + YMM6 := word_zx(word_subword (read YMM6 s) (0,128):int128) ,, + YMM7 := word_zx(word_subword (read YMM7 s) (0,128):int128) ,, + YMM8 := word_zx(word_subword (read YMM8 s) (0,128):int128) ,, + YMM9 := word_zx(word_subword (read YMM9 s) (0,128):int128) ,, + YMM10 := word_zx(word_subword (read YMM10 s) (0,128):int128) ,, + YMM11 := word_zx(word_subword (read YMM11 s) (0,128):int128) ,, + YMM12 := word_zx(word_subword (read YMM12 s) (0,128):int128) ,, + YMM13 := word_zx(word_subword (read YMM13 s) (0,128):int128) ,, + YMM14 := word_zx(word_subword (read YMM14 s) (0,128):int128) ,, + YMM15 := word_zx(word_subword (read YMM15 s) (0,128):int128)) s`;; + (*** There are really four different multiplies here. *** *** 1. x86_IMUL: a signed multiply with the same length for operands @@ -1357,6 +1376,20 @@ let x86_PMOVMSKB = new_definition let res:int16 = usimd16 (\x. if bit 7 x then word 1 else word 0) x in (dest := word_zx res:N word) s`;; +let x86_VMOVMSKPS = new_definition + `x86_VMOVMSKPS dest src (s:x86state) = + let x:int256 = read src s in + let res:byte = word( + bitval(bit 31 (word_subword x (0,32):int32)) + + 2 * bitval(bit 31 (word_subword x (32,32):int32)) + + 4 * bitval(bit 31 (word_subword x (64,32):int32)) + + 8 * bitval(bit 31 (word_subword x (96,32):int32)) + + 16 * bitval(bit 31 (word_subword x (128,32):int32)) + + 32 * bitval(bit 31 (word_subword x (160,32):int32)) + + 64 * bitval(bit 31 (word_subword x (192,32):int32)) + + 128 * bitval(bit 31 (word_subword x (224,32):int32))) in + (dest := word_zx res:int32) s`;; + (*** Push and pop are a bit odd in several ways. First of all, there is ***) (*** an implicit memory operand so this doesn't have quite the same ***) (*** "shallowness": we refer to the memory component explicitly. And we ***) @@ -1915,6 +1948,17 @@ let x86_VPMULDQ = new_definition let res:(128)word = simd2 f (word_zx x) (word_zx y) in (dest := (word_zx res):N word) s`;; +let x86_VPMOVZXBD = new_definition + `x86_VPMOVZXBD dest src (s:x86state) = + let (x:M word) = read src s in + let f = \(b:byte). word_zx b:int32 in + if dimindex(:N) = 256 then + let res:(256)word = usimd8 f (word_zx x:int64) in + (dest := (word_zx res):N word) s + else + let res:(128)word = usimd4 f (word_zx x:int32) in + (dest := (word_zx res):N word) s`;; + let x86_VPMULHRSW = new_definition `x86_VPMULHRSW dest src1 src2 (s:x86state) = let (x:N word) = read src1 s @@ -3247,6 +3291,9 @@ let x86_execute = define (\s. (match operand_size dest with 256 -> x86_VMOVDQU (OPERAND256 dest s) (OPERAND256 src s) s | 128 -> x86_VMOVDQU (OPERAND128 dest s) (OPERAND128 src s) s))) s + | VMOVMSKPS dest src -> + (add_load_event src s ,, add_store_event dest s ,, + (\s. x86_VMOVMSKPS (OPERAND32 dest s) (OPERAND256 src s) s)) s | VMOVSHDUP dest src -> (add_load_event src s ,, add_store_event dest s ,, (\s. (match operand_size dest with @@ -3418,6 +3465,15 @@ let x86_execute = define (OPERAND256 src2 s) | 128 -> x86_VPMADDWD (OPERAND128 dest s) (OPERAND128 src1 s) (OPERAND128 src2 s)) s)) s + | VPMOVZXBD dest src -> + (add_load_event src s ,, add_store_event dest s ,, + (\s. (match operand_size dest with + 256 -> (match operand_size src with + 128 -> x86_VPMOVZXBD (OPERAND256 dest s) (OPERAND128 src s) + | 64 -> x86_VPMOVZXBD (OPERAND256 dest s) (OPERAND64 src s)) + | 128 -> (match operand_size src with + 128 -> x86_VPMOVZXBD (OPERAND128 dest s) (OPERAND128 src s) + | 32 -> x86_VPMOVZXBD (OPERAND128 dest s) (OPERAND32 src s))) s)) s | VPMULDQ dest src1 src2 -> (add_load_event src1 s ,, add_load_event src2 s ,, add_store_event dest s ,, @@ -3634,6 +3690,8 @@ let x86_execute = define (\s. (match operand_size dest with 256 -> x86_VPUNPCKLQDQ (OPERAND256 dest s) (OPERAND256 src1 s) (OPERAND256 src2 s) | 128 -> x86_VPUNPCKLQDQ (OPERAND128 dest s) (OPERAND128 src1 s) (OPERAND128 src2 s)) s)) s + | VZEROUPPER -> + x86_VZEROUPPER s | XCHG dest src -> (add_load_event src s ,, add_load_event dest s ,, add_store_event dest s ,, add_store_event src s ,, @@ -4465,6 +4523,9 @@ let x86_VPBLENDVB_ALT = CONV_RULE (DEPTH_CONV DIMINDEX_CONV)) x86_VPBLENDVB;; let x86_VPMADDUBSW_ALT = EXPAND_SIMD_RULE x86_VPMADDUBSW;; let x86_VPMADDWD_ALT = EXPAND_SIMD_RULE x86_VPMADDWD;; +let x86_VMOVMSKPS_ALT = x86_VMOVMSKPS;; +let x86_VPMOVZXBD_ALT = EXPAND_SIMD_RULE x86_VPMOVZXBD;; +let x86_VZEROUPPER_ALT = x86_VZEROUPPER;; let x86_VPMULDQ_ALT = EXPAND_SIMD_RULE x86_VPMULDQ;; let x86_VPMULHRSW_ALT = EXPAND_SIMD_RULE x86_VPMULHRSW;; let x86_VPMULHW_ALT = EXPAND_SIMD_RULE x86_VPMULHW;; @@ -4518,6 +4579,7 @@ let X86_OPERATION_CLAUSES = x86_VPACKUSWB_ALT; x86_VPBLENDVB_ALT; x86_VPBLENDD_ALT; x86_VPBLENDW_ALT; x86_VPCLMULQDQ_ALT; x86_VPERMD_ALT; x86_VPERMQ_ALT; x86_VPSHUFB_ALT; x86_VPUNPCKLQDQ_ALT; x86_VPUNPCKHQDQ_ALT; x86_VPBROADCASTQ_ALT; x86_VPERM2I128_ALT; + x86_VMOVMSKPS_ALT; x86_VPMOVZXBD_ALT; x86_VZEROUPPER_ALT; (*** 32-bit backups since the ALT forms are 64-bit only ***) INST_TYPE[`:32`,`:N`] x86_ADC; INST_TYPE[`:32`,`:N`] x86_ADCX; diff --git a/x86_att/Makefile b/x86_att/Makefile index 09d2599a2..e417c4196 100644 --- a/x86_att/Makefile +++ b/x86_att/Makefile @@ -146,6 +146,7 @@ OBJ = curve25519/bignum_add_p25519.o \ mldsa/mldsa_pointwise_acc_l5.o \ mldsa/mldsa_pointwise_acc_l7.o \ mldsa/mldsa_reduce.o \ + mldsa/mldsa_rej_uniform.o \ mlkem/mlkem_basemul_k2.o \ mlkem/mlkem_basemul_k3.o \ mlkem/mlkem_basemul_k4.o \ diff --git a/x86_att/attrofy.sed b/x86_att/attrofy.sed index d00b05493..5518a5198 100644 --- a/x86_att/attrofy.sed +++ b/x86_att/attrofy.sed @@ -17,6 +17,14 @@ s/\.intel_syntax *noprefix// s/_internal_s2n_bignum_x86/_internal_s2n_bignum_x86_att/ +# Swap zero-extending loads to AT&T size-suffixed form BEFORE memory-operand +# reshaping. We rewrite the mnemonic and strip the WORD/BYTE PTR hint so the +# remaining rules just translate the memory operand and swap register order. +s/ movzx +([a-z][a-z_0-9]*d?), *WORD PTR/ movzwl \1,/g +s/ movzx +([a-z][a-z_0-9]*d?), *word ptr/ movzwl \1,/g +s/ movzx +([a-z][a-z_0-9]*d?), *BYTE PTR/ movzbl \1,/g +s/ movzx +([a-z][a-z_0-9]*d?), *byte ptr/ movzbl \1,/g + # Don't make any transforms on lines with most argument-taking macros # We need to be more careful with those taking ymm register arguments @@ -99,7 +107,18 @@ s/([[(,.;: ])([xyz]mm[0-9]*)/\1\%\2/g # Add explicit sizes to instructions +s/YMMWORD PTR//g +s/ymmword ptr//g +s/XMMWORD PTR//g +s/xmmword ptr//g s/QWORD PTR//g +s/qword ptr//g +s/DWORD PTR//g +s/dword ptr//g +s/WORD PTR//g +s/word ptr//g +s/BYTE PTR//g +s/byte ptr//g s/ adc / adcq /g s/ adcx / adcxq /g diff --git a/x86_att/mldsa/mldsa_rej_uniform.S b/x86_att/mldsa/mldsa_rej_uniform.S new file mode 100644 index 000000000..deaaf0e51 --- /dev/null +++ b/x86_att/mldsa/mldsa_rej_uniform.S @@ -0,0 +1,147 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +// ---------------------------------------------------------------------------- +// Uniform rejection sampling for ML-DSA +// Input buf[840] (uint8_t); output r[256] (int32_t); table[2048] (uint64_t) +// Returns: number of valid coefficients in r (at most 256) +// +// This function implements the rejection-sampling loop for ML-DSA, extracting +// 23-bit coefficients from packed 24-bit input bytes and keeping only those +// strictly less than q = 8380417. A main AVX2 loop processes 24 bytes (8 +// coefficients) per iteration using VPERMQ+VPSHUFB extraction, VPAND masking, +// VPSUBD+VMOVMSKPS rejection, and VPERMD+table compaction. A scalar tail +// handles any remaining bytes after the main loop exits. +// +// This implementation is derived from the public domain AVX2 Dilithium +// implementation from CRYSTALS-Dilithium optimized AVX2 implementation by +// Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé +// (https://github.com/pq-crystals/dilithium/tree/master/avx2) +// +// extern uint32_t mldsa_rej_uniform +// (int32_t r[static 256], +// const uint8_t buf[static 840], +// const uint64_t table[static 256]); +// +// Standard x86-64 ABI: RDI = r, RSI = buf, RDX = table +// Microsoft x64 ABI: RCX = r, RDX = buf, R8 = table +// ---------------------------------------------------------------------------- + +#include "_internal_s2n_bignum_x86_att.h" + + + S2N_BN_SYM_VISIBILITY_DIRECTIVE(mldsa_rej_uniform) + S2N_BN_FUNCTION_TYPE_DIRECTIVE(mldsa_rej_uniform) + S2N_BN_SYM_PRIVACY_DIRECTIVE(mldsa_rej_uniform) + .text + +S2N_BN_SYMBOL(mldsa_rej_uniform): + + _CET_ENDBR + +#if WINDOWS_ABI + pushq %rdi + pushq %rsi + subq $160, %rsp + movdqu %xmm6, 0(%rsp) + movdqu %xmm7, 16(%rsp) + movdqu %xmm8, 32(%rsp) + movdqu %xmm9, 48(%rsp) + movdqu %xmm10, 64(%rsp) + movdqu %xmm11, 80(%rsp) + movdqu %xmm12, 96(%rsp) + movdqu %xmm13, 112(%rsp) + movdqu %xmm14, 128(%rsp) + movdqu %xmm15, 144(%rsp) + movq %rcx, %rdi + movq %rdx, %rsi + movq %r8, %rdx +#endif + + // Shuffle mask: expand 24 bytes (8 x 3-byte coefficients) into + // 8 x 4-byte lanes (with a zero high byte in each). + movq $0xff050403ff020100, %r10 + vmovq %r10, %xmm0 + movq $0xff0b0a09ff080706, %r10 + vpinsrq $0x1, %r10, %xmm0, %xmm0 + movq $0xff090807ff060504, %r10 + vmovq %r10, %xmm3 + movq $0xff0f0e0dff0c0b0a, %r10 + vpinsrq $0x1, %r10, %xmm3, %xmm3 + vinserti128 $0x1, %xmm3, %ymm0, %ymm0 + + // Mask 0x7fffff in all 8 lanes (keep low 23 bits) + movl $0x7fffff, %r8d + vmovd %r8d, %xmm1 + vpbroadcastd %xmm1, %ymm1 + + // Threshold q = 0x7fe001 in all 8 lanes + movl $0x7fe001, %r8d + vmovd %r8d, %xmm2 + vpbroadcastd %xmm2, %ymm2 + + // %rax = accepted count, %rcx = byte offset into buf + xorl %eax, %eax + xorl %ecx, %ecx + +Lmldsa_rej_uniform_loop: + // Exit to scalar tail if we have 248 or more accepted (next 8 might + // overflow) or if byte offset is past 808 (would read past buf+840-24). + cmpl $0xf8, %eax + ja Lmldsa_rej_uniform_scalar + cmpl $0x328, %ecx + ja Lmldsa_rej_uniform_scalar + + vmovdqu (%rsi,%rcx), %ymm3 + addl $0x18, %ecx + vpermq $0x94, %ymm3, %ymm3 + vpshufb %ymm0, %ymm3, %ymm3 + vpand %ymm1, %ymm3, %ymm3 + vpsubd %ymm2, %ymm3, %ymm4 + vmovmskps %ymm4, %r8d + popcnt %r8d, %r9d + vmovq (%rdx,%r8,8), %xmm4 + vpmovzxbd %xmm4, %ymm4 + vpermd %ymm3, %ymm4, %ymm3 + vmovdqu %ymm3, (%rdi,%rax,4) + addl %r9d, %eax + jmp Lmldsa_rej_uniform_loop + +Lmldsa_rej_uniform_scalar: + cmpl $0x100, %eax + jae Lmldsa_rej_uniform_done + cmpl $0x345, %ecx + ja Lmldsa_rej_uniform_done + movzwl (%rsi,%rcx), %r8d + movzbl 2(%rsi,%rcx,1), %r9d + shll $0x10, %r9d + orl %r9d, %r8d + andl $0x7fffff, %r8d + addl $0x3, %ecx + cmpl $0x7fe001, %r8d + jae Lmldsa_rej_uniform_scalar + movl %r8d, (%rdi,%rax,4) + addl $0x1, %eax + jmp Lmldsa_rej_uniform_scalar + +Lmldsa_rej_uniform_done: +#if WINDOWS_ABI + movdqu 0(%rsp), %xmm6 + movdqu 16(%rsp), %xmm7 + movdqu 32(%rsp), %xmm8 + movdqu 48(%rsp), %xmm9 + movdqu 64(%rsp), %xmm10 + movdqu 80(%rsp), %xmm11 + movdqu 96(%rsp), %xmm12 + movdqu 112(%rsp), %xmm13 + movdqu 128(%rsp), %xmm14 + movdqu 144(%rsp), %xmm15 + addq $160, %rsp + popq %rsi + popq %rdi +#endif + ret + +#if defined(__linux__) && defined(__ELF__) + .section .note.GNU-stack,"",%progbits +#endif