Skip to content

perf(patch): Level-2 (2a) trunk re-entry prototype + proof#218

Closed
RhizoNymph wants to merge 5 commits into
feat/activation-patchingfrom
perf/patch-trunk-reuse
Closed

perf(patch): Level-2 (2a) trunk re-entry prototype + proof#218
RhizoNymph wants to merge 5 commits into
feat/activation-patchingfrom
perf/patch-trunk-reuse

Conversation

@RhizoNymph

@RhizoNymph RhizoNymph commented Jul 1, 2026

Copy link
Copy Markdown
Owner

Level-2 activation-patching sweeps: skip the layers below a patch site by re-entering the forward mid-stack with the cached (unpatched) corrupt trunk residual, instead of re-running the whole stack per cell. Level 1 (the shipped /v1/patch_sweep) recomputes all layers for every cell; Level-2 ("2a") recomputes only layers ≥ L.

This PR contains both the proof and the full continuous-batching + endpoint integration.

Mechanism (qwen2.py, gated)

  • _PATCH_2A_ENTRY = (start_layer, hidden_or_None). When set, Qwen2Model.forward enters the stack at start_layer using the merged residual stream as input (residual=None), reusing the existing islice(layers, start_layer, end_layer) loop. Zero overhead when unset.
  • Feeding (hidden=merged_stream, residual=None) is bit-identical to the normal path because input_layernorm(h, r) only depends on h + r. So the trunk needed is just post_block[L-1], which the capture path already produces — no new capture point, no deep forward surgery.

Continuous-batching integration

  • SamplingParams.patch_2a = {entry_layer, trunk_run} → the v2 GPU runner resolves the trunk from the existing PatchSourceStore and builds a dense entry (inputs_embeds) via one batched H2D transfer, sets a per-step start_layer, and lets the normal patch plane apply the clean source on top. Guarded to eager (a FULL cudagraph ignores model_inputs); any mixed batch / missing row degrades to a correct Level-1 forward.
  • /v1/patch_sweep gains mode="2a" + trunk_run: captures the corrupt trunk (only the needed post_block[L-1] layers) or reuses a pre-captured one, and serializes per-layer groups so each in-flight batch is homogeneous in entry_layer.

Validated (GPU, Qwen3-0.6B)

  • Correctness: tests/patch_2a_proof.py — a 2a cell reproduces the Level-1 patched distribution bit-exact (0.000e+00) across L=2→26. Endpoint (tests/patch_2a_endpoint_validate.py): 2a grid == Level-1 grid, argmax matches (small logprob deltas are batch-nondeterminism).
  • Speedup (tests/patch_2a_amortized.py, long 2349-tok prompt, 160-cell sweep, amortized trunk): 2a +27% faster than Level-1 (6931 ms vs 9490 ms; +15% including the one-time trunk capture).

Scope / usage

  • The win requires the regime 2a targets: long prompt + large per-layer groups + an amortized trunk (capture once, reuse via trunk_run). For tiny sweeps / per-sweep trunk recapture, 2a is slower — dispatch small→Level-1.
  • Run the server with --enforce-eager --no-enable-prefix-caching (2a recomputes all positions; the endpoint rejects mode=2a under prefix caching).
  • Remaining hardening: per-request prefix-floor-0 (to drop the APC-off requirement), entry-build vectorization, and adaptive dispatch.

Builds on #212 (activation patching).

@RhizoNymph

Copy link
Copy Markdown
Owner Author

Added adaptive dispatch (commit 22ea6e7): mode="auto" on /v1/patch_sweep resolves per sweep — 2a only when the prompt is long AND positions-per-layer is large (and prefix caching is off), else level1. Thresholds are tunable (auto_min_prompt_tokens, auto_min_positions; defaults 512/16 for a small model, lower for larger models). The chosen strategy is returned as mode_used. Decision is a pure helper dispatch_mode() with 6 unit tests; live-checked: short sweep→level1, long 20-position sweep→2a.

@RhizoNymph

Copy link
Copy Markdown
Owner Author

Dropped the --no-enable-prefix-caching requirement (commit 62f2739). Each 2a cell now gets a unique cache_salt (2a-{trunk_run}-{layer}-{pos}), which makes its block hashes unique — so it finds no prefix hit and recomputes every position (the floor-0 2a needs to rebuild layers≥L from the injected trunk), and its patched KV can neither be read by nor poison other requests. No engine-internals changes; reuses the existing cache_salt extra-hash-key path.

Validated with prefix caching ON (server started without --no-enable-prefix-caching): 2a grid == Level-1 (argmax matches), amortized win holds at +27% (6940ms vs 9442ms), and mode=auto now selects 2a for long sweeps under APC (the apc-off gate was removed from dispatch_mode). Server now just needs --enforce-eager.

@RhizoNymph

Copy link
Copy Markdown
Owner Author

Important negative result (commit 6caf19a). Two things:

  1. Fixed the server-wide --enforce-eager: 2a's per-step start_layer is incompatible with a captured cudagraph (it's ignored) / torch.compile (it's baked), so 2a steps must run eager+uncompiled. Now confined to only 2a steps via patch_2a_pending → skip_compiled (same need_eager path capture uses); all other traffic keeps cudagraphs.

  2. But that exposed that 2a provides no net benefit over Level-1 on a normal (APC-enabled) server — measured +1%/-0% amortized across stable reps. Root cause is fundamental: Level-1 under prefix caching already gets position-windowing for free (the patch floor recomputes only p..end, reusing prefix KV), while 2a cannot use it — it must recompute every position to rebuild layers ≥ L from the injected trunk. So 2a just trades Level-1's position-axis saving for the layer-axis saving. The math: Level-1 avg cell ≈ (n/2)·28 layers = 14n; 2a avg cell ≈ n·(28−avgL) ≈ 15n — equal. The earlier +27% was an artifact of measuring against a Level-1 crippled on both axes (--no-enable-prefix-caching + forced eager).

Recommendation: don't ship 2a. Level-1 (#212) with default prefix caching is already near-optimal; the layer-skip and APC's position-windowing are the same-magnitude savings on orthogonal axes. This PR stands as a documented, GPU-validated negative result.

@RhizoNymph RhizoNymph closed this Jul 1, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant