Skip to content

Commit 4339c2f

Browse files
committed
feat: solver-aware canonicalization for Z3 AVX intrinsics
AVX instructions like vpermd only use a subset of bits from their control operands (e.g., 3 bits out of 32 per lane). Z3 treats the unused bits as free variables, causing the solver to enumerate functionally-identical solutions that differ only in don't-care bits.

Every z3_avx function that has don't-care bits now requires a Solver parameter and adds constraints zeroing unused bits:

- _generic_permutexvar: index bits per lane
- _generic_permutex2var: offset + source selector bits per lane
- _generic_permutevar: control bits per lane (ps: 2, pd: 1)
- _generic_blendv: sign bit only per lane
- _generic_blend: element count bits of imm8
- _shuffle_pd_generic: lane count * 2 bits of imm8
- _permute_pd_generic: bits [1:0] of imm8
- _select4_128b: bit 2 unused per control nibble
- _generic_alignr: shift bits of imm8

The solver from synthesize_gadget_with_symbolic flows through _evaluate -> _dispatch_intrinsic_by_signature -> z3_avx functions. The dispatch code uses inspect.signature (cached via lru_cache) to detect which intrinsics accept the solver parameter.
1 parent 9d5f18d commit 4339c2f

File tree

4 files changed

+669
-249
lines changed

4 files changed

+669
-249
lines changed

vxsort/smallsort/codegen/src/gadget_synthesizer.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from __future__ import annotations
88

9+
import functools
10+
import inspect
911
import io
1012
import os
1113
import tarfile
@@ -237,20 +239,33 @@ def _match_dispatch_rule(arg_keys: frozenset[str]) -> _IntrinsicDispatchRule | N
237239
return None
238240

239241

240-
def _dispatch_intrinsic_fallback(intrinsic, args: dict):
242+
@functools.lru_cache(maxsize=256)
243+
def _accepts_solver(intrinsic) -> bool:
244+
"""Return True if *intrinsic* has a ``solver`` keyword parameter."""
245+
try:
246+
return "solver" in inspect.signature(intrinsic).parameters
247+
except (ValueError, TypeError):
248+
return False
249+
250+
251+
def _dispatch_intrinsic_fallback(intrinsic, args: dict, solver: Solver):
241252
"""Fallback positional dispatch that preserves incoming argument order."""
242253
arg_order = tuple(args.keys())
254+
if _accepts_solver(intrinsic):
255+
return intrinsic(*(args[key] for key in arg_order), solver=solver)
243256
return intrinsic(*(args[key] for key in arg_order))
244257

245258

246-
def _dispatch_intrinsic_by_signature(intrinsic, args: dict):
259+
def _dispatch_intrinsic_by_signature(intrinsic, args: dict, solver: Solver):
247260
"""Call an intrinsic using the first matching dispatch rule.
248261
249262
Falls back to legacy positional call order when no rule matches.
250263
"""
251264
rule = _match_dispatch_rule(frozenset(args.keys()))
252265
if rule is None:
253-
return _dispatch_intrinsic_fallback(intrinsic, args)
266+
return _dispatch_intrinsic_fallback(intrinsic, args, solver=solver)
267+
if _accepts_solver(intrinsic):
268+
return intrinsic(*(args[key] for key in rule.call_order), solver=solver)
254269
return intrinsic(*(args[key] for key in rule.call_order))
255270

256271

@@ -539,10 +554,22 @@ def synthesize_gadget_with_symbolic(
539554
eval_cache = {}
540555

541556
top_output = self._evaluate(
542-
graph.top, registers, ctx, symbolic_vars, mux_constraints, eval_cache
557+
graph.top,
558+
registers,
559+
ctx,
560+
symbolic_vars,
561+
mux_constraints,
562+
eval_cache,
563+
solver=solver,
543564
)
544565
bottom_output = self._evaluate(
545-
graph.bottom, registers, ctx, symbolic_vars, mux_constraints, eval_cache
566+
graph.bottom,
567+
registers,
568+
ctx,
569+
symbolic_vars,
570+
mux_constraints,
571+
eval_cache,
572+
solver=solver,
546573
)
547574

548575
# Identity side: pass through the input register
@@ -720,6 +747,7 @@ def _evaluate(
720747
symbolic_vars: dict,
721748
mux_constraints: list,
722749
cache: dict,
750+
solver: Solver,
723751
):
724752
"""Recursively evaluate a graph node to a Z3 expression.
725753
@@ -751,10 +779,24 @@ def _evaluate(
751779

752780
elif isinstance(node, Mux):
753781
sel = self._evaluate(
754-
node.select, registers, ctx, symbolic_vars, mux_constraints, cache
782+
node.select,
783+
registers,
784+
ctx,
785+
symbolic_vars,
786+
mux_constraints,
787+
cache,
788+
solver=solver,
755789
)
756790
sources = [
757-
self._evaluate(s, registers, ctx, symbolic_vars, mux_constraints, cache)
791+
self._evaluate(
792+
s,
793+
registers,
794+
ctx,
795+
symbolic_vars,
796+
mux_constraints,
797+
cache,
798+
solver=solver,
799+
)
758800
for s in node.sources
759801
]
760802
# Range constraint
@@ -776,9 +818,15 @@ def _evaluate(
776818
args = {}
777819
for key, operand in node.operands.items():
778820
args[key] = self._evaluate(
779-
operand, registers, ctx, symbolic_vars, mux_constraints, cache
821+
operand,
822+
registers,
823+
ctx,
824+
symbolic_vars,
825+
mux_constraints,
826+
cache,
827+
solver=solver,
780828
)
781-
result = _dispatch_intrinsic_by_signature(intrinsic_fn, args)
829+
result = _dispatch_intrinsic_by_signature(intrinsic_fn, args, solver=solver)
782830
# Apply combined mux pruning for any Mux children
783831
self._apply_mux_pruning(node, symbolic_vars, mux_constraints)
784832
else:
@@ -1092,7 +1140,9 @@ def _apply_concrete_instructions(
10921140
else:
10931141
args[key] = value
10941142

1095-
current_reg = _dispatch_intrinsic_by_signature(intrinsic, args)
1143+
current_reg = _dispatch_intrinsic_by_signature(
1144+
intrinsic, args, solver=Solver()
1145+
)
10961146
results.append(current_reg)
10971147

10981148
return current_reg, []

0 commit comments

Comments
 (0)