perf(patch): Level-2 (2a) trunk re-entry prototype + proof #218
RhizoNymph wants to merge 5 commits into
Added adaptive dispatch (commit 22ea6e7):
…oor-0 + write isolation)
Dropped the … Validated with prefix caching ON (server started without …
…f server-wide --enforce-eager
Important negative result (commit 6caf19a). Two things:
Recommendation: don't ship 2a. Level-1 (#212) with default prefix caching is already near-optimal; the layer skip and the position windowing of automatic prefix caching (APC) deliver savings of the same magnitude on orthogonal axes. This PR stands as a documented, GPU-validated negative result.
Level-2 activation-patching sweeps: skip the layers below a patch site by re-entering the forward pass mid-stack with the cached (unpatched) corrupt trunk residual, instead of re-running the whole stack per cell. Level-1 (the shipped `/v1/patch_sweep`) recomputes all layers for every cell; Level-2 ("2a") recomputes only layers ≥ L, where L is the entry layer at the patch site.

This PR contains both the proof and the full continuous-batching + endpoint integration.
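To make the re-entry idea concrete, here is a minimal NumPy sketch, not the code in this PR. Everything in it is an assumption for illustration: the toy blocks (`WEIGHTS`), the `add_rms_norm` helper, and the `forward` / `compare` functions are hypothetical stand-ins for the real model. It runs one unpatched corrupt forward to cache the merged post-block stream, then checks that re-entering at layer L with `(merged, residual=None)` matches a full patched forward.

```python
import numpy as np

D, N_LAYERS, SEQ = 16, 6, 5
rng = np.random.default_rng(0)
# Toy stand-in (assumption): one weight matrix per block instead of attention + MLP.
WEIGHTS = [(rng.standard_normal((D, D)) / np.sqrt(D)).astype(np.float32)
           for _ in range(N_LAYERS)]


def add_rms_norm(hidden, residual):
    """Fused add + RMSNorm; returns (normed, new_residual)."""
    merged = hidden if residual is None else hidden + residual
    rms = np.sqrt(np.mean(merged * merged, axis=-1, keepdims=True) + 1e-6)
    return merged / rms, merged


def forward(hidden, residual=None, start_layer=0, patch=None, capture=None):
    """Run blocks start_layer..N_LAYERS-1 and return the final merged stream.

    patch   = (layer, position, vector): overwrite that block's output row.
    capture = dict filled with the merged post-block stream of each layer run.
    """
    for i in range(start_layer, N_LAYERS):
        normed, residual = add_rms_norm(hidden, residual)
        hidden = np.tanh(normed @ WEIGHTS[i])
        if patch is not None and patch[0] == i:
            hidden[patch[1]] = patch[2]
        if capture is not None:
            capture[i] = hidden + residual
    return hidden + residual


def compare(entry_layer, position=1):
    """Return (Level-1 output, 2a-style output) for one patched cell."""
    corrupt = rng.standard_normal((SEQ, D)).astype(np.float32)
    clean_vec = rng.standard_normal(D).astype(np.float32)  # hypothetical clean source
    patch = (entry_layer, position, clean_vec)

    trunk = {}
    forward(corrupt, capture=trunk)          # one unpatched corrupt run caches the trunk
    level1 = forward(corrupt, patch=patch)   # Level-1 style: every layer recomputed
    level2a = forward(trunk[entry_layer - 1], residual=None,
                      start_layer=entry_layer, patch=patch)  # layers >= L only
    return level1, level2a


if __name__ == "__main__":
    for layer in range(1, N_LAYERS):
        full, reentry = compare(layer)
        print(layer, float(np.max(np.abs(full - reentry))))
```

The sketch only needs the merged stream of layer L-1 because the first thing block L does is add hidden and residual together; the cached sum carries the same information as the pair.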
Mechanism (`qwen2.py`, gated)

- `_PATCH_2A_ENTRY = (start_layer, hidden_or_None)`. When set, `Qwen2Model.forward` enters the stack at `start_layer` using the merged residual stream as input (`residual=None`), reusing the existing `islice(layers, start_layer, end_layer)` loop. Zero overhead when unset.
- `(hidden=merged_stream, residual=None)` is bit-identical to the normal path because `input_layernorm(h, r)` depends only on `h + r`. So the only trunk needed is `post_block[L-1]`, which the capture path already produces: no new capture point, no deep forward surgery.

Continuous-batching integration
- `SamplingParams.patch_2a = {entry_layer, trunk_run}` → the v2 GPU runner resolves the trunk from the existing `PatchSourceStore` and builds a dense entry (`inputs_embeds`) via one batched H2D transfer, sets a per-step `start_layer`, and lets the normal patch plane apply the clean source on top. Guarded to eager (a FULL cudagraph ignores `model_inputs`); any mixed batch / missing row degrades to a correct Level-1 forward.
- `/v1/patch_sweep` gains `mode="2a"` + `trunk_run`: captures the corrupt trunk (only the needed `post_block[L-1]` layers) or reuses a pre-captured one, and serializes per-layer groups so each in-flight batch is homogeneous in `entry_layer`.

Validated (GPU, Qwen3-0.6B)
- `tests/patch_2a_proof.py`: a 2a cell reproduces the Level-1 patched distribution bit-exact (0.000e+00) across L=2→26.
- Endpoint (`tests/patch_2a_endpoint_validate.py`): 2a grid == Level-1 grid, argmax matches (small logprob deltas are batch nondeterminism).
- `tests/patch_2a_amortized.py` (long 2349-tok prompt, 160-cell sweep, amortized trunk): 2a is 27% faster than Level-1 (6931 ms vs 9490 ms; 15% faster including the one-time trunk capture).

Scope / usage
- …`trunk_run`). For tiny sweeps or per-sweep trunk recapture, 2a is slower; dispatch small sweeps to Level-1.
- `--enforce-eager --no-enable-prefix-caching` (2a recomputes all positions; the endpoint rejects `mode=2a` under prefix caching).

Builds on #212 (activation patching).
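As a rough illustration of the per-layer serialization mentioned under "Continuous-batching integration", the sketch below groups sweep cells by entry layer and submits one homogeneous group at a time. It is an assumption-laden sketch, not the endpoint's code: `Cell`, `group_by_entry_layer`, `run_sweep`, and the `submit_batch` callback are hypothetical names introduced here.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Cell:
    """One sweep cell (hypothetical shape): entry layer and token position."""
    entry_layer: int
    position: int


def group_by_entry_layer(cells):
    """Return [(entry_layer, [cells...])] sorted by layer; each group is homogeneous."""
    groups = defaultdict(list)
    for cell in cells:
        groups[cell.entry_layer].append(cell)
    return sorted(groups.items())


def run_sweep(cells, submit_batch):
    """Submit one homogeneous group at a time, finishing each before the next.

    submit_batch(layer, group) is a stand-in for the real batched request path
    and must return one result per cell, in order.
    """
    results = {}
    for layer, group in group_by_entry_layer(cells):
        for cell, value in zip(group, submit_batch(layer, group)):
            results[cell] = value
    return results


if __name__ == "__main__":
    cells = [Cell(layer, pos) for pos in range(2) for layer in (6, 2, 4)]

    def fake_submit(layer, group):
        # Placeholder result so the sketch runs without a server.
        return [layer * 10 + c.position for c in group]

    for cell, value in run_sweep(cells, fake_submit).items():
        print(cell, value)
```

Serializing by group trades some batching freedom for the guarantee that no in-flight batch mixes entry layers, which is the condition the runner needs to avoid falling back to a Level-1 forward.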